From 69c9830d288d5b8d7f0abe7c8a65a4c966580a49 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 17 Sep 2015 00:48:57 -0700 Subject: [PATCH 001/896] [MINOR] [CORE] Fixes minor variable name typo Author: Cheng Lian Closes #8784 from liancheng/typo-fix. --- .../apache/spark/serializer/GenericAvroSerializerSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala index bc9f3708ed69..87f25e7245e1 100644 --- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -76,9 +76,9 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { test("caches previously seen schemas") { val genericSer = new GenericAvroSerializer(conf.getAvroSchema) val compressedSchema = genericSer.compress(schema) - val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) + val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) assert(compressedSchema.eq(genericSer.compress(schema))) - assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) + assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) } } From c633ed3260140f1288f326acc4d7a10dcd2e27d5 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:43:59 -0700 Subject: [PATCH 002/896] [SPARK-10284] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.tuning Author: Yu ISHIKAWA Closes #8694 from yu-iskw/SPARK-10284. --- python/pyspark/ml/tuning.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index cae778869e9c..ab5621f45c72 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -18,6 +18,7 @@ import itertools import numpy as np +from pyspark import since from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model from pyspark.ml.util import keyword_only @@ -47,11 +48,14 @@ class ParamGridBuilder(object): True >>> all([m in expected for m in output]) True + + .. versionadded:: 1.4.0 """ def __init__(self): self._param_grid = {} + @since("1.4.0") def addGrid(self, param, values): """ Sets the given parameters in this grid to fixed values. @@ -60,6 +64,7 @@ def addGrid(self, param, values): return self + @since("1.4.0") def baseOn(self, *args): """ Sets the given parameters in this grid to fixed values. @@ -73,6 +78,7 @@ def baseOn(self, *args): return self + @since("1.4.0") def build(self): """ Builds and returns all combinations of parameters specified @@ -104,6 +110,8 @@ class CrossValidator(Estimator): >>> cvModel = cv.fit(dataset) >>> evaluator.evaluate(cvModel.transform(dataset)) 0.8333... + + .. 
versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -142,6 +150,7 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF self._set(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): @@ -150,6 +159,7 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, num kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setEstimator(self, value): """ Sets the value of :py:attr:`estimator`. @@ -157,12 +167,14 @@ def setEstimator(self, value): self._paramMap[self.estimator] = value return self + @since("1.4.0") def getEstimator(self): """ Gets the value of estimator or its default value. """ return self.getOrDefault(self.estimator) + @since("1.4.0") def setEstimatorParamMaps(self, value): """ Sets the value of :py:attr:`estimatorParamMaps`. @@ -170,12 +182,14 @@ def setEstimatorParamMaps(self, value): self._paramMap[self.estimatorParamMaps] = value return self + @since("1.4.0") def getEstimatorParamMaps(self): """ Gets the value of estimatorParamMaps or its default value. """ return self.getOrDefault(self.estimatorParamMaps) + @since("1.4.0") def setEvaluator(self, value): """ Sets the value of :py:attr:`evaluator`. @@ -183,12 +197,14 @@ def setEvaluator(self, value): self._paramMap[self.evaluator] = value return self + @since("1.4.0") def getEvaluator(self): """ Gets the value of evaluator or its default value. """ return self.getOrDefault(self.evaluator) + @since("1.4.0") def setNumFolds(self, value): """ Sets the value of :py:attr:`numFolds`. @@ -196,6 +212,7 @@ def setNumFolds(self, value): self._paramMap[self.numFolds] = value return self + @since("1.4.0") def getNumFolds(self): """ Gets the value of numFolds or its default value. @@ -231,7 +248,15 @@ def _fit(self, dataset): bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) + @since("1.4.0") def copy(self, extra=None): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. This copies creates a deep copy of + the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ if extra is None: extra = dict() newCV = Params.copy(self, extra) @@ -246,6 +271,8 @@ def copy(self, extra=None): class CrossValidatorModel(Model): """ Model from k-fold cross validation. + + .. versionadded:: 1.4.0 """ def __init__(self, bestModel): @@ -256,6 +283,7 @@ def __init__(self, bestModel): def _transform(self, dataset): return self.bestModel.transform(dataset) + @since("1.4.0") def copy(self, extra=None): """ Creates a copy of this instance with a randomly generated uid From 29bf8aa5a51fdd8c2600533297f991e14fa27c03 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:45:20 -0700 Subject: [PATCH 003/896] [SPARK-10283] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.regression Author: Yu ISHIKAWA Closes #8693 from yu-iskw/SPARK-10283. 
--- python/pyspark/ml/regression.py | 65 +++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index a9503608b7f2..21d454f9003b 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -62,6 +63,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.4.0 """ @keyword_only @@ -81,6 +84,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, standardization=True): @@ -96,13 +100,31 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) + @since("1.4.0") + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + @since("1.4.0") + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + class LinearRegressionModel(JavaModel): """ Model fitted by LinearRegression. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def weights(self): """ Model weights. @@ -110,6 +132,7 @@ def weights(self): return self._call_java("weights") @property + @since("1.4.0") def intercept(self): """ Model intercept. @@ -162,6 +185,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -193,6 +218,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -209,6 +235,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return DecisionTreeRegressionModel(java_model) + @since("1.4.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -216,6 +243,7 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.4.0") def getImpurity(self): """ Gets the value of impurity or its default value. @@ -225,13 +253,19 @@ def getImpurity(self): @inherit_doc class DecisionTreeModel(JavaModel): + """Abstraction for Decision Tree models. + + .. 
versionadded:: 1.5.0 + """ @property + @since("1.5.0") def numNodes(self): """Return number of nodes of the decision tree.""" return self._call_java("numNodes") @property + @since("1.5.0") def depth(self): """Return depth of the decision tree.""" return self._call_java("depth") @@ -242,8 +276,13 @@ def __repr__(self): @inherit_doc class TreeEnsembleModels(JavaModel): + """Represents a tree ensemble model. + + .. versionadded:: 1.5.0 + """ @property + @since("1.5.0") def treeWeights(self): """Return the weights for each tree""" return list(self._call_java("javaTreeWeights")) @@ -256,6 +295,8 @@ def __repr__(self): class DecisionTreeRegressionModel(DecisionTreeModel): """ Model fitted by DecisionTreeRegressor. + + .. versionadded:: 1.4.0 """ @@ -282,6 +323,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 0.5 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -336,6 +379,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, @@ -353,6 +397,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) + @since("1.4.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -360,12 +405,14 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.4.0") def getImpurity(self): """ Gets the value of impurity or its default value. """ return self.getOrDefault(self.impurity) + @since("1.4.0") def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. @@ -373,12 +420,14 @@ def setSubsamplingRate(self, value): self._paramMap[self.subsamplingRate] = value return self + @since("1.4.0") def getSubsamplingRate(self): """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") def setNumTrees(self, value): """ Sets the value of :py:attr:`numTrees`. @@ -386,12 +435,14 @@ def setNumTrees(self, value): self._paramMap[self.numTrees] = value return self + @since("1.4.0") def getNumTrees(self): """ Gets the value of numTrees or its default value. """ return self.getOrDefault(self.numTrees) + @since("1.4.0") def setFeatureSubsetStrategy(self, value): """ Sets the value of :py:attr:`featureSubsetStrategy`. @@ -399,6 +450,7 @@ def setFeatureSubsetStrategy(self, value): self._paramMap[self.featureSubsetStrategy] = value return self + @since("1.4.0") def getFeatureSubsetStrategy(self): """ Gets the value of featureSubsetStrategy or its default value. @@ -409,6 +461,8 @@ def getFeatureSubsetStrategy(self): class RandomForestRegressionModel(TreeEnsembleModels): """ Model fitted by RandomForestRegressor. + + .. versionadded:: 1.4.0 """ @@ -435,6 +489,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. 
versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -481,6 +537,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -498,6 +555,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return GBTRegressionModel(java_model) + @since("1.4.0") def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. @@ -505,12 +563,14 @@ def setLossType(self, value): self._paramMap[self.lossType] = value return self + @since("1.4.0") def getLossType(self): """ Gets the value of lossType or its default value. """ return self.getOrDefault(self.lossType) + @since("1.4.0") def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. @@ -518,12 +578,14 @@ def setSubsamplingRate(self, value): self._paramMap[self.subsamplingRate] = value return self + @since("1.4.0") def getSubsamplingRate(self): """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. @@ -531,6 +593,7 @@ def setStepSize(self, value): self._paramMap[self.stepSize] = value return self + @since("1.4.0") def getStepSize(self): """ Gets the value of stepSize or its default value. @@ -541,6 +604,8 @@ def getStepSize(self): class GBTRegressionModel(TreeEnsembleModels): """ Model fitted by GBTRegressor. + + .. versionadded:: 1.4.0 """ From 0ded87a4d49d4484e202bd2ec781821b57b5882c Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:47:21 -0700 Subject: [PATCH 004/896] [SPARK-10281] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.clustering Author: Yu ISHIKAWA Closes #8691 from yu-iskw/SPARK-10281. --- python/pyspark/ml/clustering.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index cb4c16e25a7a..7bb8ab94e17d 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -26,8 +27,11 @@ class KMeansModel(JavaModel): """ Model fitted by KMeans. + + .. versionadded:: 1.5.0 """ + @since("1.5.0") def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] @@ -55,6 +59,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol True >>> rows[2].prediction == rows[3].prediction True + + .. 
versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -88,6 +94,7 @@ def _create_model(self, java_model): return KMeansModel(java_model) @keyword_only + @since("1.5.0") def setParams(self, featuresCol="features", predictionCol="prediction", k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): """ @@ -99,6 +106,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setK(self, value): """ Sets the value of :py:attr:`k`. @@ -110,12 +118,14 @@ def setK(self, value): self._paramMap[self.k] = value return self + @since("1.5.0") def getK(self): """ Gets the value of `k` """ return self.getOrDefault(self.k) + @since("1.5.0") def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. @@ -130,12 +140,14 @@ def setInitMode(self, value): self._paramMap[self.initMode] = value return self + @since("1.5.0") def getInitMode(self): """ Gets the value of `initMode` """ return self.getOrDefault(self.initMode) + @since("1.5.0") def setInitSteps(self, value): """ Sets the value of :py:attr:`initSteps`. @@ -147,6 +159,7 @@ def setInitSteps(self, value): self._paramMap[self.initSteps] = value return self + @since("1.5.0") def getInitSteps(self): """ Gets the value of `initSteps` From 39b44cb52eb225469eb4ccdf696f0bc6405b9184 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:48:45 -0700 Subject: [PATCH 005/896] [SPARK-10278] [MLLIB] [PYSPARK] Add @since annotation to pyspark.mllib.tree Author: Yu ISHIKAWA Closes #8685 from yu-iskw/SPARK-10278. --- python/pyspark/mllib/tree.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 372b86a7c95d..0001b60093a6 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -19,7 +19,7 @@ import random -from pyspark import SparkContext, RDD +from pyspark import SparkContext, RDD, since from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -30,6 +30,11 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): + """TreeEnsembleModel + + .. versionadded:: 1.3.0 + """ + @since("1.3.0") def predict(self, x): """ Predict values for a single data point or an RDD of points using @@ -45,12 +50,14 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.3.0") def numTrees(self): """ Get number of trees in ensemble. """ return self.call("numTrees") + @since("1.3.0") def totalNumNodes(self): """ Get total number of nodes, summed over all trees in the @@ -62,6 +69,7 @@ def __repr__(self): """ Summary of model """ return self._java_model.toString() + @since("1.3.0") def toDebugString(self): """ Full model """ return self._java_model.toDebugString() @@ -72,7 +80,10 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): .. note:: Experimental A decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ + @since("1.1.0") def predict(self, x): """ Predict the label of one or more examples. 
@@ -90,16 +101,23 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.1.0") def numNodes(self): + """Get number of nodes in tree, including leaf nodes.""" return self._java_model.numNodes() + @since("1.1.0") def depth(self): + """Get depth of tree. + E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + """ return self._java_model.depth() def __repr__(self): """ summary of model. """ return self._java_model.toString() + @since("1.2.0") def toDebugString(self): """ full model. """ return self._java_model.toDebugString() @@ -115,6 +133,8 @@ class DecisionTree(object): Learning algorithm for a decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ @classmethod @@ -127,6 +147,7 @@ def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, m return DecisionTreeModel(model) @classmethod + @since("1.1.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -185,6 +206,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) @classmethod + @since("1.1.0") def trainRegressor(cls, data, categoricalFeaturesInfo, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -239,6 +261,8 @@ class RandomForestModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a random forest model. + + .. versionadded:: 1.2.0 """ @classmethod @@ -252,6 +276,8 @@ class RandomForest(object): Learning algorithm for a random forest model for classification or regression. + + .. versionadded:: 1.2.0 """ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -271,6 +297,7 @@ def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees, return RandomForestModel(model) @classmethod + @since("1.2.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, seed=None): @@ -352,6 +379,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, maxDepth, maxBins, seed) @classmethod + @since("1.2.0") def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="variance", maxDepth=4, maxBins=32, seed=None): """ @@ -418,6 +446,8 @@ class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a gradient-boosted tree model. + + .. versionadded:: 1.3.0 """ @classmethod @@ -431,6 +461,8 @@ class GradientBoostedTrees(object): Learning algorithm for a gradient boosted trees model for classification or regression. + + .. 
versionadded:: 1.3.0 """ @classmethod @@ -443,6 +475,7 @@ def _train(cls, data, algo, categoricalFeaturesInfo, return GradientBoostedTreesModel(model) @classmethod + @since("1.3.0") def trainClassifier(cls, data, categoricalFeaturesInfo, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): @@ -505,6 +538,7 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, loss, numIterations, learningRate, maxDepth, maxBins) @classmethod + @since("1.3.0") def trainRegressor(cls, data, categoricalFeaturesInfo, loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): From 4a0b56e8dbb3713b16e58738201d838ffc4b258b Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:50:00 -0700 Subject: [PATCH 006/896] [SPARK-10279] [MLLIB] [PYSPARK] [DOCS] Add @since annotation to pyspark.mllib.util Author: Yu ISHIKAWA Closes #8689 from yu-iskw/SPARK-10279. --- python/pyspark/mllib/util.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 10a1e4b3eb0f..39bc6586dd58 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -23,7 +23,7 @@ xrange = range basestring = str -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -32,6 +32,8 @@ class MLUtils(object): """ Helper methods to load, save and pre-process data used in MLlib. + + .. versionadded:: 1.0.0 """ @staticmethod @@ -69,6 +71,7 @@ def _convert_labeled_point_to_libsvm(p): return " ".join(items) @staticmethod + @since("1.0.0") def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None): """ Loads labeled data in the LIBSVM format into an RDD of @@ -123,6 +126,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) @staticmethod + @since("1.0.0") def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. @@ -147,6 +151,7 @@ def saveAsLibSVMFile(data, dir): lines.saveAsTextFile(dir) @staticmethod + @since("1.1.0") def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. @@ -172,6 +177,7 @@ def loadLabeledPoints(sc, path, minPartitions=None): return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) @staticmethod + @since("1.5.0") def appendBias(data): """ Returns a new vector with `1.0` (bias) appended to @@ -186,6 +192,7 @@ def appendBias(data): return _convert_to_vector(np.append(vec.toArray(), 1.0)) @staticmethod + @since("1.5.0") def loadVectors(sc, path): """ Loads vectors saved using `RDD[Vector].saveAsTextFile` @@ -197,6 +204,8 @@ def loadVectors(sc, path): class Saveable(object): """ Mixin for models and transformers which may be saved as files. + + .. versionadded:: 1.3.0 """ def save(self, sc, path): @@ -222,9 +231,13 @@ class JavaSaveable(Saveable): """ Mixin for models that provide save() through their Scala implementation. + + .. 
versionadded:: 1.3.0 """ + @since("1.3.0") def save(self, sc, path): + """Save this model to the given path.""" if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) if not isinstance(path, basestring): @@ -235,6 +248,8 @@ def save(self, sc, path): class Loader(object): """ Mixin for classes which can load saved models from files. + + .. versionadded:: 1.3.0 """ @classmethod @@ -256,6 +271,8 @@ class JavaLoader(Loader): """ Mixin for classes which can load saved models using its Scala implementation. + + .. versionadded:: 1.3.0 """ @classmethod @@ -280,15 +297,21 @@ def _load_java(cls, sc, path): return java_obj.load(sc._jsc.sc(), path) @classmethod + @since("1.3.0") def load(cls, sc, path): + """Load a model from the given path.""" java_model = cls._load_java(sc, path) return cls(java_model) class LinearDataGenerator(object): - """Utils for generating linear data""" + """Utils for generating linear data. + + .. versionadded:: 1.5.0 + """ @staticmethod + @since("1.5.0") def generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps): """ @@ -311,6 +334,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, xVariance, int(nPoints), int(seed), float(eps))) @staticmethod + @since("1.5.0") def generateLinearRDD(sc, nexamples, nfeatures, eps, nParts=2, intercept=0.0): """ From c74d38fd8faf8cba981cf934341d24b9a3167025 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:50:46 -0700 Subject: [PATCH 007/896] [SPARK-10274] [MLLIB] Add @since annotation to pyspark.mllib.fpm Author: Yu ISHIKAWA Closes #8665 from yu-iskw/SPARK-10274. --- python/pyspark/mllib/fpm.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index bdc4a132b1b1..bdabba9602a8 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -19,7 +19,7 @@ from numpy import array from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc @@ -41,8 +41,11 @@ class FPGrowthModel(JavaModelWrapper): >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + + .. versionadded:: 1.4.0 """ + @since("1.4.0") def freqItemsets(self): """ Returns the frequent itemsets of this model. @@ -55,9 +58,12 @@ class FPGrowth(object): .. note:: Experimental A Parallel FP-growth algorithm to mine frequent itemsets. + + .. versionadded:: 1.4.0 """ @classmethod + @since("1.4.0") def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. @@ -74,6 +80,8 @@ def train(cls, data, minSupport=0.3, numPartitions=-1): class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): """ Represents an (items, freq) tuple. + + .. versionadded:: 1.4.0 """ From 268088b899e6e165e746aed87840d47bfaf50c43 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:51:19 -0700 Subject: [PATCH 008/896] [SPARK-10282] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.recommendation Author: Yu ISHIKAWA Closes #8692 from yu-iskw/SPARK-10282. 
--- python/pyspark/ml/recommendation.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index b06099ac0aee..ec5748a1cfe9 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -80,6 +81,8 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=3.19...) >>> predictions[2] Row(user=2, item=0, prediction=-1.15...) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -122,6 +125,7 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10): @@ -137,6 +141,7 @@ def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItem def _create_model(self, java_model): return ALSModel(java_model) + @since("1.4.0") def setRank(self, value): """ Sets the value of :py:attr:`rank`. @@ -144,12 +149,14 @@ def setRank(self, value): self._paramMap[self.rank] = value return self + @since("1.4.0") def getRank(self): """ Gets the value of rank or its default value. """ return self.getOrDefault(self.rank) + @since("1.4.0") def setNumUserBlocks(self, value): """ Sets the value of :py:attr:`numUserBlocks`. @@ -157,12 +164,14 @@ def setNumUserBlocks(self, value): self._paramMap[self.numUserBlocks] = value return self + @since("1.4.0") def getNumUserBlocks(self): """ Gets the value of numUserBlocks or its default value. """ return self.getOrDefault(self.numUserBlocks) + @since("1.4.0") def setNumItemBlocks(self, value): """ Sets the value of :py:attr:`numItemBlocks`. @@ -170,12 +179,14 @@ def setNumItemBlocks(self, value): self._paramMap[self.numItemBlocks] = value return self + @since("1.4.0") def getNumItemBlocks(self): """ Gets the value of numItemBlocks or its default value. """ return self.getOrDefault(self.numItemBlocks) + @since("1.4.0") def setNumBlocks(self, value): """ Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. @@ -183,6 +194,7 @@ def setNumBlocks(self, value): self._paramMap[self.numUserBlocks] = value self._paramMap[self.numItemBlocks] = value + @since("1.4.0") def setImplicitPrefs(self, value): """ Sets the value of :py:attr:`implicitPrefs`. @@ -190,12 +202,14 @@ def setImplicitPrefs(self, value): self._paramMap[self.implicitPrefs] = value return self + @since("1.4.0") def getImplicitPrefs(self): """ Gets the value of implicitPrefs or its default value. """ return self.getOrDefault(self.implicitPrefs) + @since("1.4.0") def setAlpha(self, value): """ Sets the value of :py:attr:`alpha`. @@ -203,12 +217,14 @@ def setAlpha(self, value): self._paramMap[self.alpha] = value return self + @since("1.4.0") def getAlpha(self): """ Gets the value of alpha or its default value. """ return self.getOrDefault(self.alpha) + @since("1.4.0") def setUserCol(self, value): """ Sets the value of :py:attr:`userCol`. 
@@ -216,12 +232,14 @@ def setUserCol(self, value): self._paramMap[self.userCol] = value return self + @since("1.4.0") def getUserCol(self): """ Gets the value of userCol or its default value. """ return self.getOrDefault(self.userCol) + @since("1.4.0") def setItemCol(self, value): """ Sets the value of :py:attr:`itemCol`. @@ -229,12 +247,14 @@ def setItemCol(self, value): self._paramMap[self.itemCol] = value return self + @since("1.4.0") def getItemCol(self): """ Gets the value of itemCol or its default value. """ return self.getOrDefault(self.itemCol) + @since("1.4.0") def setRatingCol(self, value): """ Sets the value of :py:attr:`ratingCol`. @@ -242,12 +262,14 @@ def setRatingCol(self, value): self._paramMap[self.ratingCol] = value return self + @since("1.4.0") def getRatingCol(self): """ Gets the value of ratingCol or its default value. """ return self.getOrDefault(self.ratingCol) + @since("1.4.0") def setNonnegative(self, value): """ Sets the value of :py:attr:`nonnegative`. @@ -255,6 +277,7 @@ def setNonnegative(self, value): self._paramMap[self.nonnegative] = value return self + @since("1.4.0") def getNonnegative(self): """ Gets the value of nonnegative or its default value. @@ -265,14 +288,18 @@ def getNonnegative(self): class ALSModel(JavaModel): """ Model fitted by ALS. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def rank(self): """rank of the matrix factorization model""" return self._call_java("rank") @property + @since("1.4.0") def userFactors(self): """ a DataFrame that stores user factors in two columns: `id` and @@ -281,6 +308,7 @@ def userFactors(self): return self._call_java("userFactors") @property + @since("1.4.0") def itemFactors(self): """ a DataFrame that stores item factors in two columns: `id` and From e51345e1e04e439827a07c95887d14ba38333057 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 17 Sep 2015 09:17:43 -0700 Subject: [PATCH 009/896] [SPARK-10077] [DOCS] [ML] Add package info for java of ml/feature Should be the same as SPARK-7808 but use Java for the code example. It would be great to add package doc for `spark.ml.feature`. Author: Holden Karau Closes #8740 from holdenk/SPARK-10077-JAVA-PACKAGE-DOC-FOR-SPARK.ML.FEATURE. --- .../apache/spark/ml/feature/package-info.java | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java new file mode 100644 index 000000000000..c22d2e0cd2d9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +/** + * Feature transformers + * + * The `ml.feature` package provides common feature transformers that help convert raw data or + * features into more suitable forms for model fitting. + * Most feature transformers are implemented as {@link org.apache.spark.ml.Transformer}s, which + * transforms one {@link org.apache.spark.sql.DataFrame} into another, e.g., + * {@link org.apache.spark.feature.HashingTF}. + * Some feature transformers are implemented as {@link org.apache.spark.ml.Estimator}}s, because the + * transformation requires some aggregated information of the dataset, e.g., document + * frequencies in {@link org.apache.spark.ml.feature.IDF}. + * For those feature transformers, calling {@link org.apache.spark.ml.Estimator#fit} is required to + * obtain the model first, e.g., {@link org.apache.spark.ml.feature.IDFModel}, in order to apply + * transformation. + * The transformation is usually done by appending new columns to the input + * {@link org.apache.spark.sql.DataFrame}, so all input columns are carried over. + * + * We try to make each transformer minimal, so it becomes flexible to assemble feature + * transformation pipelines. + * {@link org.apache.spark.ml.Pipeline} can be used to chain feature transformers, and + * {@link org.apache.spark.ml.feature.VectorAssembler} can be used to combine multiple feature + * transformations, for example: + * + *
+ * <pre>
+ * <code>
+ *   import java.util.Arrays;
+ *
+ *   import org.apache.spark.api.java.JavaRDD;
+ *   import static org.apache.spark.sql.types.DataTypes.*;
+ *   import org.apache.spark.sql.types.StructType;
+ *   import org.apache.spark.sql.DataFrame;
+ *   import org.apache.spark.sql.RowFactory;
+ *   import org.apache.spark.sql.Row;
+ *
+ *   import org.apache.spark.ml.feature.*;
+ *   import org.apache.spark.ml.Pipeline;
+ *   import org.apache.spark.ml.PipelineStage;
+ *   import org.apache.spark.ml.PipelineModel;
+ *
+ *  // a DataFrame with three columns: id (integer), text (string), and rating (double).
+ *  StructType schema = createStructType(
+ *    Arrays.asList(
+ *      createStructField("id", IntegerType, false),
+ *      createStructField("text", StringType, false),
+ *      createStructField("rating", DoubleType, false)));
+ *  JavaRDD<Row> rowRDD = jsc.parallelize(
+ *    Arrays.asList(
+ *      RowFactory.create(0, "Hi I heard about Spark", 3.0),
+ *      RowFactory.create(1, "I wish Java could use case classes", 4.0),
+ *      RowFactory.create(2, "Logistic regression models are neat", 4.0)));
+ *  DataFrame df = jsql.createDataFrame(rowRDD, schema);
+ *  // define feature transformers
+ *  RegexTokenizer tok = new RegexTokenizer()
+ *    .setInputCol("text")
+ *    .setOutputCol("words");
+ *  StopWordsRemover sw = new StopWordsRemover()
+ *    .setInputCol("words")
+ *    .setOutputCol("filtered_words");
+ *  HashingTF tf = new HashingTF()
+ *    .setInputCol("filtered_words")
+ *    .setOutputCol("tf")
+ *    .setNumFeatures(10000);
+ *  IDF idf = new IDF()
+ *    .setInputCol("tf")
+ *    .setOutputCol("tf_idf");
+ *  VectorAssembler assembler = new VectorAssembler()
+ *    .setInputCols(new String[] {"tf_idf", "rating"})
+ *    .setOutputCol("features");
+ *
+ *  // assemble and fit the feature transformation pipeline
+ *  Pipeline pipeline = new Pipeline()
+ *    .setStages(new PipelineStage[] {tok, sw, tf, idf, assembler});
+ *  PipelineModel model = pipeline.fit(df);
+ *
+ *  // save transformed features with raw data
+ *  model.transform(df)
+ *    .select("id", "text", "rating", "features")
+ *    .write().format("parquet").save("/output/path");
+ * </code>
+ * </pre>
+ * + * Some feature transformers implemented in MLlib are inspired by those implemented in scikit-learn. + * The major difference is that most scikit-learn feature transformers operate eagerly on the entire + * input dataset, while MLlib's feature transformers operate lazily on individual columns, + * which is more efficient and flexible to handle large and complex datasets. + * + * @see + * scikit-learn.preprocessing + */ +package org.apache.spark.ml.feature; From 2a508df20d03b3d4a3c05b65fb02d849bc080ef9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Sep 2015 09:21:21 -0700 Subject: [PATCH 010/896] [SPARK-10459] [SQL] Do not need to have ConvertToSafe for PythonUDF JIRA: https://issues.apache.org/jira/browse/SPARK-10459 As mentioned in the JIRA, `PythonUDF` actually could process `UnsafeRow`. Specially, the rows in `childResults` in `BatchPythonEvaluation` will be projected to a `MutableRow`. So I think we can enable `canProcessUnsafeRows` for `BatchPythonEvaluation` and get rid of redundant `ConvertToSafe`. Author: Liang-Chi Hsieh Closes #8616 from viirya/pyudf-unsafe. --- .../scala/org/apache/spark/sql/execution/pythonUDFs.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 5a58d846ad80..d0411da6fdf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -337,6 +337,10 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { val childResults = child.execute().map(_.copy()) From c88bb5df94f9696677c3a429472114bc66f32a52 Mon Sep 17 00:00:00 2001 From: "yangping.wu" Date: Thu, 17 Sep 2015 09:52:40 -0700 Subject: [PATCH 011/896] [SPARK-10660] Doc describe error in the "Running Spark on YARN" page MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In the Configuration section, the **spark.yarn.driver.memoryOverhead** and **spark.yarn.am.memoryOverhead**‘s default value should be "driverMemory * 0.10, with minimum of 384" and "AM memory * 0.10, with minimum of 384" respectively. Because from Spark 1.4.0, the **MEMORY_OVERHEAD_FACTOR** is set to 0.1.0, not 0.07. Author: yangping.wu Closes #8797 from 397090770/SparkOnYarnDocError. --- docs/running-on-yarn.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d1244323edff..3a961d245f3d 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -211,14 +211,14 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.driver.memoryOverhead - driverMemory * 0.07, with minimum of 384 + driverMemory * 0.10, with minimum of 384 The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). 
spark.yarn.am.memoryOverhead - AM memory * 0.07, with minimum of 384 + AM memory * 0.10, with minimum of 384 Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode. From 136c77d8bbf48f7c45dd7c3fbe261a0476f455fe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Sep 2015 10:02:15 -0700 Subject: [PATCH 012/896] [SPARK-10642] [PYSPARK] Fix crash when calling rdd.lookup() on tuple keys JIRA: https://issues.apache.org/jira/browse/SPARK-10642 When calling `rdd.lookup()` on a RDD with tuple keys, `portable_hash` will return a long. That causes `DAGScheduler.submitJob` to throw `java.lang.ClassCastException: java.lang.Long cannot be cast to java.lang.Integer`. Author: Liang-Chi Hsieh Closes #8796 from viirya/fix-pyrdd-lookup. --- python/pyspark/rdd.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9ef60a7e2c84..ab5aab1e115f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -84,7 +84,7 @@ def portable_hash(x): h ^= len(x) if h == -1: h = -2 - return h + return int(h) return hash(x) @@ -2192,6 +2192,9 @@ def lookup(self, key): [42] >>> sorted.lookup(1024) [] + >>> rdd2 = sc.parallelize([(('a', 'b'), 'c')]).groupByKey() + >>> list(rdd2.lookup(('a', 'b'))[0]) + ['c'] """ values = self.filter(lambda kv: kv[0] == key).values() From 81b4db374dd61b6f1c30511c70b6ab2a52c68faa Mon Sep 17 00:00:00 2001 From: Josiah Samuel Date: Thu, 17 Sep 2015 10:18:21 -0700 Subject: [PATCH 013/896] [SPARK-10172] [CORE] disable sort in HistoryServer webUI This pull request is to address the JIRA SPARK-10172 (History Server web UI gets messed up when sorting on any column). The content of the table gets messed up due to the rowspan attribute of the table data(cell) during sorting. The current table sort library used in SparkUI (sorttable.js) doesn't support/handle cells(td) with rowspans. The fix will disable the table sort in the web UI, when there are jobs listed with multiple attempts. Author: Josiah Samuel Closes #8506 from josiahsams/SPARK-10172. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0830cc1ba124..b347cb3be69f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -51,7 +51,10 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val hasMultipleAttempts = appsToShow.exists(_.attempts.size > 1) val appTable = if (hasMultipleAttempts) { - UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, appsToShow) + // Sorting is disable here as table sort on rowspan has issues. + // ref. SPARK-10172 + UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, + appsToShow, sortable = false) } else { UIUtils.listingTable(appHeader, appRow, appsToShow) } From 36d8b278d82e788bf583e8438fac524d0023311d Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 17 Sep 2015 10:25:18 -0700 Subject: [PATCH 014/896] [SPARK-10531] [CORE] AppId is set as AppName in status rest api Verify it manually. Author: Jeff Zhang Closes #8688 from zjffdu/SPARK-10531. 
--- .../main/scala/org/apache/spark/SparkContext.scala | 1 + .../spark/deploy/history/FsHistoryProvider.scala | 9 ++++----- .../scala/org/apache/spark/deploy/master/Master.scala | 2 +- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 11 ++++++----- .../scala/org/apache/spark/ui/UISeleniumSuite.scala | 2 +- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a2f34eafa2c3..9c3218719f7f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -521,6 +521,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _applicationId = _taskScheduler.applicationId() _applicationAttemptId = taskScheduler.applicationAttemptId() _conf.set("spark.app.id", _applicationId) + _ui.foreach(_.setAppId(_applicationId)) _env.blockManager.initialize(_applicationId) // The metrics system for Driver need to be set spark.app.id to app ID. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a5755eac3639..8eb2ba1e8683 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -146,16 +146,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val ui = { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } val appListener = new ApplicationEventListener() replayBus.addListener(appListener) - val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - appInfo.map { info => - ui.setAppName(s"${info.name} ($appId)") - + val appAttemptInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), + replayBus) + appAttemptInfo.map { info => val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) ui.getSecurityManager.setAcls(uiAclsEnabled) // make sure to set admin acls before view acls so they are properly picked up diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 26904d39a9be..d518e92133aa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -944,7 +944,7 @@ private[deploy] class Master( val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), - appName + status, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) + appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) try { replayBus.replay(logInput, eventLogFile, maybeTruncated) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index d8b90568b7b9..99085ada9f0a 100644 --- 
a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -56,6 +56,8 @@ private[spark] class SparkUI private ( val stagesTab = new StagesTab(this) + var appId: String = _ + /** Initialize all components of the server. */ def initialize() { attachTab(new JobsTab(this)) @@ -75,9 +77,8 @@ private[spark] class SparkUI private ( def getAppName: String = appName - /** Set the app name for this UI. */ - def setAppName(name: String) { - appName = name + def setAppId(id: String): Unit = { + appId = id } /** Stop the server behind this web interface. Only valid after bind(). */ @@ -94,12 +95,12 @@ private[spark] class SparkUI private ( private[spark] def appUIAddress = s"http://$appUIHostPort" def getSparkUI(appId: String): Option[SparkUI] = { - if (appId == appName) Some(this) else None + if (appId == this.appId) Some(this) else None } def getApplicationInfoList: Iterator[ApplicationInfo] = { Iterator(new ApplicationInfo( - id = appName, + id = appId, name = appName, attempts = Seq(new ApplicationAttemptInfo( attemptId = None, diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 22e30ecaf053..18eec7da9763 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -658,6 +658,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.appUIAddress + "/api/v1/applications/test/" + path) + new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) } } From e0dc2bc232206d2f4da4278502c1f88babc8b55a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 17 Sep 2015 11:05:30 -0700 Subject: [PATCH 015/896] [SPARK-10650] Clean before building docs The [published docs for 1.5.0](http://spark.apache.org/docs/1.5.0/api/java/org/apache/spark/streaming/) have a bunch of test classes in them. The only way I can reproduce this is to `test:compile` before running `unidoc`. To prevent this from happening again, I've added a clean before doc generation. Author: Michael Armbrust Closes #8787 from marmbrus/testsInDocs. --- docs/_plugins/copy_api_dirs.rb | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 15ceda11a8a8..01718d98dffe 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -26,12 +26,15 @@ curr_dir = pwd cd("..") - puts "Running 'build/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `build/sbt -Pkinesis-asl compile unidoc` + puts "Running 'build/sbt -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." + puts `build/sbt -Pkinesis-asl clean compile unidoc` puts "Moving back into docs dir." cd("docs") + puts "Removing old docs" + puts `rm -rf api` + # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. source = "../target/scala-2.10/unidoc" From aad644fbe29151aec9004817d42e4928bdb326f3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 17 Sep 2015 11:14:52 -0700 Subject: [PATCH 016/896] [SPARK-10639] [SQL] Need to convert UDAF's result from scala to sql type https://issues.apache.org/jira/browse/SPARK-10639 Author: Yin Huai Closes #8788 from yhuai/udafConversion. 
--- .../sql/catalyst/CatalystTypeConverters.scala | 7 +- .../spark/sql/RandomDataGenerator.scala | 16 ++- .../spark/sql/execution/aggregate/udaf.scala | 37 +++++- .../org/apache/spark/sql/QueryTest.scala | 21 ++-- .../spark/sql/UserDefinedTypeSuite.scala | 11 ++ .../execution/AggregationQuerySuite.scala | 108 +++++++++++++++++- 6 files changed, 188 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 966623ed017b..f25591794abd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -138,8 +138,13 @@ object CatalystTypeConverters { private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + // toCatalyst (it calls toCatalystImpl) will do null check. override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) - override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + + override def toScala(catalystValue: Any): Any = { + if (catalystValue == null) null else udt.deserialize(catalystValue) + } + override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column, udt.sqlType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 4025cbcec101..e48395028e39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -108,7 +108,21 @@ object RandomDataGenerator { arr }) case BooleanType => Some(() => rand.nextBoolean()) - case DateType => Some(() => new java.sql.Date(rand.nextInt())) + case DateType => + val generator = + () => { + var milliseconds = rand.nextLong() % 253402329599999L + // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT + // for "0001-01-01 00:00:00.000000". We need to find a + // number that is greater or equals to this number as a valid timestamp value. + while (milliseconds < -62135740800000L) { + // 253402329599999L is the the number of milliseconds since + // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". 
+ milliseconds = rand.nextLong() % 253402329599999L + } + DateTimeUtils.toJavaDate((milliseconds / DateTimeUtils.MILLIS_PER_DAY).toInt) + } + Some(generator) case TimestampType => val generator = () => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index d43d3dd9ffaa..1114fe6552bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -40,6 +40,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < getters.length) { getters(i) = dataTypes(i) match { + case NullType => + (row: InternalRow, ordinal: Int) => null + case BooleanType => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal) @@ -74,6 +77,14 @@ sealed trait BufferSetterGetterUtils { (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale) + case DateType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getInt(ordinal) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getLong(ordinal) + case other => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.get(ordinal, other) @@ -92,6 +103,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < setters.length) { setters(i) = dataTypes(i) match { + case NullType => + (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal) + case b: BooleanType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { @@ -150,9 +164,23 @@ sealed trait BufferSetterGetterUtils { case dt: DecimalType => val precision = dt.precision + (row: MutableRow, ordinal: Int, value: Any) => + // To make it work with UnsafeRow, we cannot use setNullAt. + // Please see the comment of UnsafeRow's setDecimal. + row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + + case DateType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { - row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + row.setInt(ordinal, value.asInstanceOf[Int]) + } else { + row.setNullAt(ordinal) + } + + case TimestampType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setLong(ordinal, value.asInstanceOf[Long]) } else { row.setNullAt(ordinal) } @@ -205,6 +233,7 @@ private[sql] class MutableAggregationBufferImpl ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } + toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i))) } @@ -352,6 +381,10 @@ private[sql] case class ScalaUDAF( } } + private[this] lazy val outputToCatalystConverter: Any => Any = { + CatalystTypeConverters.createToCatalystConverter(dataType) + } + // This buffer is only used at executor side. 
private[this] var inputAggregateBuffer: InputAggregationBuffer = null @@ -424,7 +457,7 @@ private[sql] case class ScalaUDAF( override def eval(buffer: InternalRow): Any = { evalAggregateBuffer.underlyingInputBuffer = buffer - udaf.evaluate(evalAggregateBuffer) + outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer)) } override def toString: String = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index cada03e9ac6b..e3c5a426671d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -115,19 +115,26 @@ object QueryTest { */ def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for // equality test. - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } + val converted: Seq[Row] = answer.map(prepareRow) if (!isSorted) converted.sortBy(_.toString()) else converted } val sparkAnswer = try df.collect().toSeq catch { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 46d87843dfa4..7992fd59ff4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -163,4 +164,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { assert(new MyDenseVectorUDT().typeName === "mydensevector") assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") } + + test("Catalyst type converter null handling for UDTs") { + val udt = new MyDenseVectorUDT() + val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt) + assert(toScalaConverter(null) === null) + + val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt) + assert(toCatalystConverter(null) === null) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index a73b1bd52c09..24b1846923c7 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,13 +17,55 @@ package org.apache.spark.sql.hive.execution +import scala.collection.JavaConverters._ + import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton +class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { + + def inputSchema: StructType = schema + + def bufferSchema: StructType = schema + + def dataType: DataType = schema + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + (0 until schema.length).foreach { i => + buffer.update(i, null) + } + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!input.isNullAt(0) && input.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer.update(i, input.get(i)) + } + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer1.update(i, buffer2.get(i)) + } + } + } + + def evaluate(buffer: Row): Any = { + Row.fromSeq(buffer.toSeq) + } +} + abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -508,6 +550,70 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) } } + + test("udaf with all data types") { + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + // Right now, we will use SortBasedAggregate to handle UDAFs. + // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use + // UnsafeRow as the aggregation buffer. While, dataTypes will trigger + // SortBasedAggregate to use a safe row as the aggregation buffer. + Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes => + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + // The schema used for data generator. + val schemaForGenerator = StructType(fields) + // The schema used for the DataFrame df. + val schema = StructType(StructField("id", IntegerType) +: fields) + + logInfo(s"Testing schema: ${schema.treeString}") + + val udaf = new ScalaAggregateFunction(schema) + // Generate data at the driver side. We need to materialize the data first and then + // create RDD. 
+ val maybeDataGenerator = + RandomDataGenerator.forType( + dataType = schemaForGenerator, + nullable = true, + seed = Some(System.nanoTime())) + val dataGenerator = + maybeDataGenerator + .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator")) + val data = (1 to 50).map { i => + dataGenerator.apply() match { + case row: Row => Row.fromSeq(i +: row.toSeq) + case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null)) + case other => + fail(s"Row or null is expected to be generated, " + + s"but a ${other.getClass.getCanonicalName} is generated.") + } + } + + // Create a DF for the schema with random data. + val rdd = sqlContext.sparkContext.parallelize(data, 1) + val df = sqlContext.createDataFrame(rdd, schema) + + val allColumns = df.schema.fields.map(f => col(f.name)) + val expectedAnaswer = + data + .find(r => r.getInt(0) == 50) + .getOrElse(fail("A row with id 50 should be the expected answer.")) + checkAnswer( + df.groupBy().agg(udaf(allColumns: _*)), + // udaf returns a Row as the output value. + Row(expectedAnaswer) + ) + } + } } class SortBasedAggregationQuerySuite extends AggregationQuerySuite { From 64743870f23bffb8d96dcc8a0181c1452782a151 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Sep 2015 11:24:38 -0700 Subject: [PATCH 017/896] [SPARK-10394] [ML] Make GBTParams use shared stepSize ```GBTParams``` has ```stepSize``` as learning rate currently. ML has shared param class ```HasStepSize```, ```GBTParams``` can extend from it rather than duplicated implementation. Author: Yanbo Liang Closes #8552 from yanboliang/spark-10394. --- .../org/apache/spark/ml/tree/treeParams.scala | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index d29f5253c9c3..42e74ce6d2c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -365,17 +365,7 @@ private[ml] object RandomForestParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { - - /** - * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each - * estimator. - * (default = 0.1) - * @group param - */ - final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + - " learning rate) in interval (0, 1] for shrinking the contribution of each estimator", - ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { /* TODO: Add this doc when we add this param. 
SPARK-7132 * Threshold for stopping early when runWithValidation is used. @@ -393,11 +383,19 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) - /** @group setParam */ + /** + * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each + * estimator. + * (default = 0.1) + * @group setParam + */ def setStepSize(value: Double): this.type = set(stepSize, value) - /** @group getParam */ - final def getStepSize: Double = $(stepSize) + override def validateParams(): Unit = { + require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)( + getStepSize), "GBT parameter stepSize should be in interval (0, 1], " + + s"but it given invalid value $getStepSize.") + } /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ private[ml] def getOldBoostingStrategy( From f1c911552cf5d0d60831c79c1881016293aec66c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 17 Sep 2015 11:40:24 -0700 Subject: [PATCH 018/896] [SPARK-10657] Remove SCP-based Jenkins log archiving As of https://issues.apache.org/jira/browse/SPARK-7561, we no longer need to use our custom SCP-based mechanism for archiving Jenkins logs on the master machine; this has been superseded by the use of a Jenkins plugin which archives the logs and provides public links to view them. Per shaneknapp, we should remove this log syncing mechanism if it is no longer necessary; removing the need to SCP from the Jenkins workers to the masters is a desired step as part of some larger Jenkins infra refactoring. Author: Josh Rosen Closes #8793 from JoshRosen/remove-jenkins-ssh-to-master. --- dev/run-tests-jenkins | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3be78575e70f..d3b05fa6df0c 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -116,39 +116,6 @@ function post_message () { fi } -function send_archived_logs () { - echo "Archiving unit tests logs..." - - local log_files=$( - find .\ - -name "unit-tests.log" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.failed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.hiveFailed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.wrong" - ) - - if [ -z "$log_files" ]; then - echo "> No log files found." >&2 - else - local log_archive="unit-tests-logs.tar.gz" - echo "$log_files" | xargs tar czf ${log_archive} - - local jenkins_build_dir=${JENKINS_HOME}/jobs/${JOB_NAME}/builds/${BUILD_NUMBER} - local scp_output=$(scp ${log_archive} amp-jenkins-master:${jenkins_build_dir}/${log_archive}) - local scp_status="$?" - - if [ "$scp_status" -ne 0 ]; then - echo "Failed to send archived unit tests logs to Jenkins master." >&2 - echo "> scp_status: ${scp_status}" >&2 - echo "> scp_output: ${scp_output}" >&2 - else - echo "> Send successful." - fi - - rm -f ${log_archive} - fi -} - # post start message { start_message="\ @@ -244,8 +211,6 @@ done test_result_note=" * This patch **fails $failing_test**." fi - - send_archived_logs } # post end message From 4fbf3328692e876f39ea78494510f9d9c5a53f15 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 17 Sep 2015 14:09:06 -0700 Subject: [PATCH 019/896] [SPARK-9698] [ML] Add RInteraction transformer for supporting R-style feature interactions This is a pre-req for supporting the ":" operator in the RFormula feature transformer. 
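A minimal usage sketch, assuming an existing `sqlContext` (the toy column names, attribute metadata, and expected output mirror the numeric case in the new `InteractionSuite` below):

```scala
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.feature.Interaction
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.functions.col

// Toy input: a numeric column "a" and a vector column "b".
val data = sqlContext.createDataFrame(Seq(
  (2.0, Vectors.dense(3.0, 4.0)),
  (1.0, Vectors.dense(1.0, 5.0))
)).toDF("a", "b")

// Vector inputs must carry ML attributes so the transformer knows how many
// output slots each feature needs (nominal features get one-hot encoded).
val bGroup = new AttributeGroup("b", Array[Attribute](
  NumericAttribute.defaultAttr.withName("foo"),
  NumericAttribute.defaultAttr.withName("bar")))
val df = data.select(
  col("a").as("a", NumericAttribute.defaultAttr.toMetadata()),
  col("b").as("b", bGroup.toMetadata()))

val interacted = new Interaction()
  .setInputCols(Array("a", "b"))
  .setOutputCol("features")
  .transform(df)

// All-numeric inputs yield the cross-product, e.g. 2.0 and [3.0, 4.0] -> [6.0, 8.0].
interacted.select("features").show()
```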
Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit mengxr Author: Eric Liang Closes #7987 from ericl/interaction. --- .../apache/spark/ml/feature/Interaction.scala | 278 ++++++++++++++++++ .../spark/ml/feature/InteractionSuite.scala | 165 +++++++++++ 2 files changed, 443 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala new file mode 100644 index 000000000000..9194763fb32f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.Transformer +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the feature interaction transform. This transformer takes in Double and Vector type + * columns and outputs a flattened vector of their feature interactions. To handle interaction, + * we first one-hot encode any nominal features. Then, a vector of the feature cross-products is + * produced. + * + * For example, given the input feature values `Double(2)` and `Vector(3, 4)`, the output would be + * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal + * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. 
+ */ +@Experimental +class Interaction(override val uid: String) extends Transformer + with HasInputCols with HasOutputCol { + + def this() = this(Identifiable.randomUID("interaction")) + + /** @group setParam */ + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + validateParams() + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) + } + + override def transform(dataset: DataFrame): DataFrame = { + validateParams() + val inputFeatures = $(inputCols).map(c => dataset.schema(c)) + val featureEncoders = getFeatureEncoders(inputFeatures) + val featureAttrs = getFeatureAttrs(inputFeatures) + + def interactFunc = udf { row: Row => + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + var size = 1 + indices += 0 + values += 1.0 + var featureIndex = row.length - 1 + while (featureIndex >= 0) { + val prevIndices = indices.result() + val prevValues = values.result() + val prevSize = size + val currentEncoder = featureEncoders(featureIndex) + indices = ArrayBuilder.make[Int] + values = ArrayBuilder.make[Double] + size *= currentEncoder.outputSize + currentEncoder.foreachNonzeroOutput(row(featureIndex), (i, a) => { + var j = 0 + while (j < prevIndices.length) { + indices += prevIndices(j) + i * prevSize + values += prevValues(j) * a + j += 1 + } + }) + featureIndex -= 1 + } + Vectors.sparse(size, indices.result(), values.result()).compressed + } + + val featureCols = inputFeatures.map { f => + f.dataType match { + case DoubleType => dataset(f.name) + case _: VectorUDT => dataset(f.name) + case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType) + } + } + dataset.select( + col("*"), + interactFunc(struct(featureCols: _*)).as($(outputCol), featureAttrs.toMetadata())) + } + + /** + * Creates a feature encoder for each input column, which supports efficient iteration over + * one-hot encoded feature values. See also the class-level comment of [[FeatureEncoder]]. + * + * @param features The input feature columns to create encoders for. + */ + private def getFeatureEncoders(features: Seq[StructField]): Array[FeatureEncoder] = { + def getNumFeatures(attr: Attribute): Int = { + attr match { + case nominal: NominalAttribute => + math.max(1, nominal.getNumValues.getOrElse( + throw new SparkException("Nominal features must have attr numValues defined."))) + case _ => + 1 // numeric feature + } + } + features.map { f => + val numFeatures = f.dataType match { + case _: NumericType | BooleanType => + Array(getNumFeatures(Attribute.fromStructField(f))) + case _: VectorUDT => + val attrs = AttributeGroup.fromStructField(f).attributes.getOrElse( + throw new SparkException("Vector attributes must be defined for interaction.")) + attrs.map(getNumFeatures).toArray + } + new FeatureEncoder(numFeatures) + }.toArray + } + + /** + * Generates ML attributes for the output vector of all feature interactions. We make a best + * effort to generate reasonable names for output features, based on the concatenation of the + * interacting feature names and values delimited with `_`. When no feature name is specified, + * we fall back to using the feature index (e.g. `foo:bar_2_0` may indicate an interaction + * between the numeric `foo` feature and a nominal third feature from column `bar`. 
+ * + * @param features The input feature columns to the Interaction transformer. + */ + private def getFeatureAttrs(features: Seq[StructField]): AttributeGroup = { + var featureAttrs: Seq[Attribute] = Nil + features.reverse.foreach { f => + val encodedAttrs = f.dataType match { + case _: NumericType | BooleanType => + val attr = Attribute.fromStructField(f) + encodedFeatureAttrs(Seq(attr), None) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(f) + encodedFeatureAttrs(group.attributes.get, Some(group.name)) + } + if (featureAttrs.isEmpty) { + featureAttrs = encodedAttrs + } else { + featureAttrs = encodedAttrs.flatMap { head => + featureAttrs.map { tail => + NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) + } + } + } + } + new AttributeGroup($(outputCol), featureAttrs.toArray) + } + + /** + * Generates the output ML attributes for a single input feature. Each output feature name has + * up to three parts: the group name, feature name, and category name (for nominal features), + * each separated by an underscore. + * + * @param inputAttrs The attributes of the input feature. + * @param groupName Optional name of the input feature group (for Vector type features). + */ + private def encodedFeatureAttrs( + inputAttrs: Seq[Attribute], + groupName: Option[String]): Seq[Attribute] = { + + def format( + index: Int, + attrName: Option[String], + categoryName: Option[String]): String = { + val parts = Seq(groupName, Some(attrName.getOrElse(index.toString)), categoryName) + parts.flatten.mkString("_") + } + + inputAttrs.zipWithIndex.flatMap { + case (nominal: NominalAttribute, i) => + if (nominal.values.isDefined) { + nominal.values.get.map( + v => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(v)))) + } else { + Array.tabulate(nominal.getNumValues.get)( + j => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(j.toString)))) + } + case (a: Attribute, i) => + Seq(NumericAttribute.defaultAttr.withName(format(i, a.name, None))) + } + } + + override def copy(extra: ParamMap): Interaction = defaultCopy(extra) + + override def validateParams(): Unit = { + require(get(inputCols).isDefined, "Input cols must be defined first.") + require(get(outputCol).isDefined, "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") + } +} + +/** + * This class performs on-the-fly one-hot encoding of features as you iterate over them. To + * indicate which input features should be one-hot encoded, an array of the feature counts + * must be passed in ahead of time. + * + * @param numFeatures Array of feature counts for each input feature. For nominal features this + * count is equal to the number of categories. For numeric features the count + * should be set to 1. + */ +private[ml] class FeatureEncoder(numFeatures: Array[Int]) { + assert(numFeatures.forall(_ > 0), "Features counts must all be positive.") + + /** The size of the output vector. */ + val outputSize = numFeatures.sum + + /** Precomputed offsets for the location of each output feature. */ + private val outputOffsets = { + val arr = new Array[Int](numFeatures.length) + var i = 1 + while (i < arr.length) { + arr(i) = arr(i - 1) + numFeatures(i - 1) + i += 1 + } + arr + } + + /** + * Given an input row of features, invokes the specific function for every non-zero output. 
+ * + * @param value The row value to encode, either a Double or Vector. + * @param f The callback to invoke on each non-zero (index, value) output pair. + */ + def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { + case d: Double => + assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + val numOutputCols = numFeatures.head + if (numOutputCols > 1) { + assert( + d >= 0.0 && d == d.toInt && d < numOutputCols, + s"Values from column must be indices, but got $d.") + f(d.toInt, 1.0) + } else { + f(0, d) + } + case vec: Vector => + assert(numFeatures.length == vec.size, + s"Vector column size was ${vec.size}, expected ${numFeatures.length}") + vec.foreachActive { (i, v) => + val numOutputCols = numFeatures(i) + if (numOutputCols > 1) { + assert( + v >= 0.0 && v == v.toInt && v < numOutputCols, + s"Values from column must be indices, but got $v.") + f(outputOffsets(i) + v.toInt, 1.0) + } else { + f(outputOffsets(i), v) + } + } + case null => + throw new SparkException("Values to interact cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala new file mode 100644 index 000000000000..2beb62ca0823 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.functions.col + +class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new Interaction()) + } + + test("feature encoder") { + def encode(cardinalities: Array[Int], value: Any): Vector = { + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + val encoder = new FeatureEncoder(cardinalities) + encoder.foreachNonzeroOutput(value, (i, v) => { + indices += i + values += v + }) + Vectors.sparse(encoder.outputSize, indices.result(), values.result()).compressed + } + assert(encode(Array(1), 2.2) === Vectors.dense(2.2)) + assert(encode(Array(3), Vectors.dense(1)) === Vectors.dense(0, 1, 0)) + assert(encode(Array(1, 1), Vectors.dense(1.1, 2.2)) === Vectors.dense(1.1, 2.2)) + assert(encode(Array(3, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 0, 2.2)) + assert(encode(Array(2, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 2.2)) + assert(encode(Array(2, 1, 1), Vectors.dense(0, 2.2, 0)) === Vectors.dense(1, 0, 2.2, 0)) + intercept[SparkException] { encode(Array(1), "foo") } + intercept[SparkException] { encode(Array(1), null) } + intercept[AssertionError] { encode(Array(2), 2.2) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(2.2)) } + intercept[AssertionError] { encode(Array(1), Vectors.dense(1.0, 2.0, 3.0)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(-1)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(3)) } + } + + test("numeric interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a:b_foo"), Some(1)), + new NumericAttribute(Some("a:b_bar"), Some(2)))) + assert(attrs === expectedAttrs) + } + + test("nominal interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as( + "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val 
res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_up:b_foo"), Some(1)), + new NumericAttribute(Some("a_up:b_bar"), Some(2)), + new NumericAttribute(Some("a_down:b_foo"), Some(3)), + new NumericAttribute(Some("a_down:b_bar"), Some(4)), + new NumericAttribute(Some("a_left:b_foo"), Some(5)), + new NumericAttribute(Some("a_left:b_bar"), Some(6)))) + assert(attrs === expectedAttrs) + } + + test("default attr names") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0), + (1, Vectors.dense(1.0, 5.0), 10.0)) + ).toDF("a", "b", "c") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NominalAttribute.defaultAttr.withNumValues(2), + NumericAttribute.defaultAttr)) + val df = data.select( + col("a").as("a", NominalAttribute.defaultAttr.withNumValues(3).toMetadata()), + col("b").as("b", groupAttr.toMetadata()), + col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), + (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) + ).toDF("a", "b", "c", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_0:b_0_0:c"), Some(1)), + new NumericAttribute(Some("a_0:b_0_1:c"), Some(2)), + new NumericAttribute(Some("a_0:b_1:c"), Some(3)), + new NumericAttribute(Some("a_1:b_0_0:c"), Some(4)), + new NumericAttribute(Some("a_1:b_0_1:c"), Some(5)), + new NumericAttribute(Some("a_1:b_1:c"), Some(6)), + new NumericAttribute(Some("a_2:b_0_0:c"), Some(7)), + new NumericAttribute(Some("a_2:b_0_1:c"), Some(8)), + new NumericAttribute(Some("a_2:b_1:c"), Some(9)))) + assert(attrs === expectedAttrs) + } +} From 0f5ef6dfa67a068606aff8ea9d1addfce73446eb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 17 Sep 2015 19:16:34 -0700 Subject: [PATCH 020/896] [SPARK-10674] [TESTS] Increase timeouts in SaslIntegrationSuite. 1s seems to trigger too many times on the jenkins build boxes, so increase the timeout and cross fingers. Author: Marcelo Vanzin Closes #8802 from vanzin/SPARK-10674 and squashes the following commits: 3c93117 [Marcelo Vanzin] Use java 7 syntax. d667d1b [Marcelo Vanzin] [SPARK-10674] [tests] Increase timeouts in SaslIntegrationSuite. 
--- .../spark/network/sasl/SaslIntegrationSuite.java | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 5cb0e4d4a645..c393a5e1e681 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -56,6 +56,11 @@ import org.apache.spark.network.util.TransportConf; public class SaslIntegrationSuite { + + // Use a long timeout to account for slow / overloaded build machines. In the normal case, + // tests should finish way before the timeout expires. + private final static long TIMEOUT_MS = 10_000; + static TransportServer server; static TransportConf conf; static TransportContext context; @@ -102,7 +107,7 @@ public void testGoodClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; - byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS); assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg } @@ -131,7 +136,7 @@ public void testNoSaslClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.sendRpcSync(new byte[13], 1000); + client.sendRpcSync(new byte[13], TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -139,7 +144,7 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -217,12 +222,12 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { new String[] { System.getProperty("java.io.tmpdir") }, 1, "org.apache.spark.shuffle.sort.SortShuffleManager"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteArray(), 10000); + client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS); // Make a successful request to fetch blocks, which creates a new stream. But do not actually // fetch any blocks, to keep the stream open. OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000); + byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS); StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); long streamId = stream.streamId; From 98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Sep 2015 21:37:10 -0700 Subject: [PATCH 021/896] [SPARK-8518] [ML] Log-linear models for survival analysis [Accelerated Failure Time (AFT) model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) is the most commonly used and easy to parallel method of survival analysis for censored survival data. It is the log-linear model based on the Weibull distribution of the survival time. 
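A rough usage sketch of the new estimator, assuming an existing `sqlContext` and illustrative toy data (a `features` vector, a positive survival time as `label`, and `censor` = 1.0 for an observed event, 0.0 for a censored one; the censor column name defaults to "censor"):

```scala
import org.apache.spark.ml.regression.AFTSurvivalRegression
import org.apache.spark.mllib.linalg.Vectors

// Toy censored survival data.
val training = sqlContext.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")

val aft = new AFTSurvivalRegression()
  .setMaxIter(100)
val model = aft.fit(training)

// Response-type prediction, i.e. exp(coefficients . x + intercept), comparable
// to R's predict(fit, type = "response") for a survreg model.
println(model.predict(Vectors.dense(1.0, -1.0)))

// Quantile predictions of the survival time require the probabilities to be set first.
model.setQuantileProbabilities(Array(0.1, 0.5, 0.9))
println(model.predictQuantiles(Vectors.dense(1.0, -1.0)))

// The fitted Weibull AFT parameters.
println(s"coefficients: ${model.coefficients} intercept: ${model.intercept} scale: ${model.scale}")
```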
Users can refer to the R function [```survreg```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) to compare the model and [```predict```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/predict.survreg.html) to compare the prediction. There are different kinds of model prediction, I have just select the type ```response``` which is default used for R. Author: Yanbo Liang Closes #8611 from yanboliang/spark-8518. --- .../ml/regression/AFTSurvivalRegression.scala | 449 ++++++++++++++++++ .../AFTSurvivalRegressionSuite.scala | 311 ++++++++++++ 2 files changed, 760 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala new file mode 100644 index 000000000000..5b25db651f56 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.collection.mutable + +import breeze.linalg.{DenseVector => BDV} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.storage.StorageLevel + +/** + * Params for accelerated failure time (AFT) regression. + */ +private[regression] trait AFTSurvivalRegressionParams extends Params + with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter + with HasTol with HasFitIntercept { + + /** + * Param for censor column name. + * The value of this column could be 0 or 1. + * If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored. 
+ * @group param + */ + @Since("1.6.0") + final val censorCol: Param[String] = new Param(this, "censorCol", "censor column name") + + /** @group getParam */ + @Since("1.6.0") + def getCensorCol: String = $(censorCol) + setDefault(censorCol -> "censor") + + /** + * Param for quantile probabilities array. + * Values of the quantile probabilities array should be in the range [0, 1]. + * @group param + */ + @Since("1.6.0") + final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this, + "quantileProbabilities", "quantile probabilities array", + (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1))) + + /** @group getParam */ + @Since("1.6.0") + def getQuantileProbabilities: Array[Double] = $(quantileProbabilities) + + /** Checks whether the input has quantile probabilities array. */ + protected[regression] def hasQuantileProbabilities: Boolean = { + isDefined(quantileProbabilities) && $(quantileProbabilities).size != 0 + } + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param fitting whether this is in fitting or prediction + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + if (fitting) { + SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + } + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } +} + +/** + * :: Experimental :: + * Fit a parametric survival regression model named accelerated failure time (AFT) model + * ([[https://en.wikipedia.org/wiki/Accelerated_failure_time_model]]) + * based on the Weibull distribution of the survival time. + */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String) + extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("aftSurvReg")) + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setCensorCol(value: String): this.type = set(censorCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** + * Set if we should fit the intercept + * Default is true. + * @group setParam + */ + @Since("1.6.0") + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + setDefault(fitIntercept -> true) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + @Since("1.6.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + @Since("1.6.0") + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, + * and put it in an RDD with strong types. 
+ */ + protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { + dataset.select($(featuresCol), $(labelCol), $(censorCol)).map { + case Row(features: Vector, label: Double, censor: Double) => + AFTPoint(features, label, censor) + } + } + + @Since("1.6.0") + override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = { + validateAndTransformSchema(dataset.schema, fitting = true) + val instances = extractAFTPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val costFun = new AFTCostFun(instances, $(fitIntercept)) + val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + + val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size + /* + The weights vector has three parts: + the first element: Double, log(sigma), the log of scale parameter + the second element: Double, intercept of the beta parameter + the third to the end elements: Doubles, regression coefficients vector of the beta parameter + */ + val initialWeights = Vectors.zeros(numFeatures + 2) + + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialWeights.toBreeze.toDenseVector) + + val weights = { + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + if (state == null) { + val msg = s"${optimizer.getClass.getName} failed." + throw new SparkException(msg) + } + + state.x.toArray.clone() + } + + if (handlePersistence) instances.unpersist() + + val coefficients = Vectors.dense(weights.slice(2, weights.length)) + val intercept = weights(1) + val scale = math.exp(weights(0)) + val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) + copyValues(model.setParent(this)) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model produced by [[AFTSurvivalRegression]]. 
+ */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegressionModel private[ml] ( + @Since("1.6.0") override val uid: String, + @Since("1.6.0") val coefficients: Vector, + @Since("1.6.0") val intercept: Double, + @Since("1.6.0") val scale: Double) + extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value) + + @Since("1.6.0") + def predictQuantiles(features: Vector): Vector = { + require(hasQuantileProbabilities, + "AFTSurvivalRegressionModel predictQuantiles must set quantile probabilities array") + // scale parameter for the Weibull distribution of lifetime + val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) + // shape parameter for the Weibull distribution of lifetime + val k = 1 / scale + val quantiles = $(quantileProbabilities).map { + q => lambda * math.exp(math.log(-math.log(1 - q)) / k) + } + Vectors.dense(quantiles) + } + + @Since("1.6.0") + def predict(features: Vector): Double = { + math.exp(BLAS.dot(coefficients, features) + intercept) + } + + @Since("1.6.0") + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema) + val predictUDF = udf { features: Vector => predict(features) } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegressionModel = { + copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) + .setParent(parent) + } +} + +/** + * AFTAggregator computes the gradient and loss for a AFT loss function, + * as used in AFT survival regression for samples in sparse or dense vector in a online fashion. + * + * The loss function and likelihood function under the AFT model based on: + * Lawless, J. F., Statistical Models and Methods for Lifetime Data, + * New York: John Wiley & Sons, Inc. 2003. + * + * Two AFTAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * Given the values of the covariates x^{'}, for random lifetime t_{i} of subjects i = 1, ..., n, + * with possible right-censoring, the likelihood function under the AFT model is given as + * {{{ + * L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0} + * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0} + * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} + * }}} + * Where \delta_{i} is the indicator of the event has occurred i.e. uncensored or not. + * Using \epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}, the log-likelihood function + * assumes the form + * {{{ + * \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+ + * \delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] + * }}} + * Where S_{0}(\epsilon_{i}) is the baseline survivor function, + * and f_{0}(\epsilon_{i}) is corresponding density function. + * + * The most commonly used log-linear survival regression method is based on the Weibull + * distribution of the survival time. 
The Weibull distribution for lifetime corresponding + * to extreme value distribution for log of the lifetime, + * and the S_{0}(\epsilon) function is + * {{{ + * S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) + * }}} + * the f_{0}(\epsilon_{i}) function is + * {{{ + * f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) + * }}} + * The log-likelihood function for Weibull distribution of lifetime is + * {{{ + * \iota(\beta,\sigma)= + * -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] + * }}} + * Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, + * the loss function we use to optimize is -\iota(\beta,\sigma). + * The gradient functions for \beta and \log\sigma respectively are + * {{{ + * \frac{\partial (-\iota)}{\partial \beta}= + * \sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} + * }}} + * {{{ + * \frac{\partial (-\iota)}{\partial (\log\sigma)}= + * \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] + * }}} + * @param weights The log of scale parameter, the intercept and + * regression coefficients corresponding to the features. + * @param fitIntercept Whether to fit an intercept term. + */ +private class AFTAggregator(weights: BDV[Double], fitIntercept: Boolean) + extends Serializable { + + // beta is the intercept and regression coefficients to the covariates + private val beta = weights.slice(1, weights.length) + // sigma is the scale parameter of the AFT model + private val sigma = math.exp(weights(0)) + + private var totalCnt: Long = 0L + private var lossSum = 0.0 + private var gradientBetaSum = BDV.zeros[Double](beta.length) + private var gradientLogSigmaSum = 0.0 + + def count: Long = totalCnt + + def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt + + // Here we optimize loss function over beta and log(sigma) + def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), + gradientBetaSum/totalCnt.toDouble) + + /** + * Add a new training data to this AFTAggregator, and update the loss and gradient + * of the objective function. + * + * @param data The AFTPoint representation for one data point to be added into this aggregator. + * @return This AFTAggregator object. + */ + def add(data: AFTPoint): this.type = { + + // TODO: Don't create a new xi vector each time. + val xi = if (fitIntercept) { + Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze + } else { + Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze + } + val ti = data.label + val delta = data.censor + val epsilon = (math.log(ti) - beta.dot(xi)) / sigma + + lossSum += math.log(sigma) * delta + lossSum += (math.exp(epsilon) - delta * epsilon) + + // Sanity check (should never occur): + assert(!lossSum.isInfinity, + s"AFTAggregator loss sum is infinity. Error for unknown reason.") + + gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma + gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon + + totalCnt += 1 + this + } + + /** + * Merge another AFTAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other AFTAggregator to be merged. + * @return This AFTAggregator object. 
+ */ + def merge(other: AFTAggregator): this.type = { + if (totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + + gradientBetaSum += other.gradientBetaSum + gradientLogSigmaSum += other.gradientLogSigmaSum + } + this + } +} + +/** + * AFTCostFun implements Breeze's DiffFunction[T] for AFT cost. + * It returns the loss and gradient at a particular point (coefficients). + * It's used in Breeze's convex optimization routines. + */ +private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean) + extends DiffFunction[BDV[Double]] { + + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + + val aftAggregator = data.treeAggregate(new AFTAggregator(coefficients, fitIntercept))( + seqOp = (c, v) => (c, v) match { + case (aggregator, instance) => aggregator.add(instance) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + (aftAggregator.loss, aftAggregator.gradient) + } +} + +/** + * Class that represents the (features, label, censor) of a data point. + * + * @param features List of features for this data point. + * @param label Label for this data point. + * @param censor Indicator of the event has occurred or not. If the value is 1, it means + * the event has occurred i.e. uncensored; otherwise censored. + */ +private[regression] case class AFTPoint(features: Vector, label: Double, censor: Double) { + require(censor == 1.0 || censor == 0.0, "censor of class AFTPoint must be 1.0 or 0.0") +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala new file mode 100644 index 000000000000..ca7140a45ea6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, DataFrame} + +class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var datasetUnivariate: DataFrame = _ + @transient var datasetMultivariate: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + datasetUnivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0))) + datasetMultivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + } + + test("params") { + ParamsSuite.checkParams(new AFTSurvivalRegression) + val model = new AFTSurvivalRegressionModel("aftSurvReg", Vectors.dense(0.0), 0.0, 0.0) + ParamsSuite.checkParams(model) + } + + test("aft survival regression: default params") { + val aftr = new AFTSurvivalRegression + assert(aftr.getLabelCol === "label") + assert(aftr.getFeaturesCol === "features") + assert(aftr.getPredictionCol === "prediction") + assert(aftr.getCensorCol === "censor") + assert(aftr.getFitIntercept) + assert(aftr.getMaxIter === 100) + assert(aftr.getTol === 1E-6) + val model = aftr.fit(datasetUnivariate) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + model.transform(datasetUnivariate) + .select("label", "prediction") + .collect() + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + } + + def generateAFTInput( + numFeatures: Int, + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + weibullShape: Double, + weibullScale: Double, + exponentialMean: Double): Seq[AFTPoint] = { + + def censor(x: Double, y: Double): Double = { if (x <= y) 1.0 else 0.0 } + + val weibull = new WeibullGenerator(weibullShape, weibullScale) + weibull.setSeed(seed) + + val exponential = new ExponentialGenerator(exponentialMean) + exponential.setSeed(seed) + + val rnd = new Random(seed) + val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](numFeatures)(rnd.nextDouble())) + + x.foreach { v => + var i = 0 + val len = v.length + while (i < len) { + v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + } + val y = (1 to nPoints).map { i => (weibull.nextValue(), exponential.nextValue()) } + + y.zip(x).map { p => AFTPoint(Vectors.dense(p._2), p._1._1, censor(p._1._1, p._1._2)) } + } + + test("aft survival regression with univariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetUnivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- data$V1 + censor <- data$V2 + label <- data$V3 + sr.fit <- survreg(Surv(label, censor) ~ features, dist='weibull') + summary(sr.fit) + + Value Std. 
Error z p + (Intercept) 1.759 0.4141 4.247 2.16e-05 + features -0.039 0.0735 -0.531 5.96e-01 + Log(scale) 0.344 0.0379 9.073 1.16e-19 + + Scale= 1.41 + + Weibull distribution + Loglik(model)= -1152.2 Loglik(intercept only)= -1152.3 + Chisq= 0.28 on 1 degrees of freedom, p= 0.6 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.039) + val interceptR = 1.759 + val scaleR = 1.41 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + + testdata <- list(features=6.559282795753792) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.494763 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.1879174 2.6801195 14.5779394 + */ + val features = Vectors.dense(6.559282795753792) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.494763 + val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetUnivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression with multivariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + (Intercept) 1.9206 0.1057 18.171 8.78e-74 + feature1 -0.0844 0.0611 -1.381 1.67e-01 + feature2 0.0677 0.0468 1.447 1.48e-01 + Log(scale) -0.0236 0.0436 -0.542 5.88e-01 + + Scale= 0.977 + + Weibull distribution + Loglik(model)= -1070.7 Loglik(intercept only)= -1072.7 + Chisq= 3.91 on 2 degrees of freedom, p= 0.14 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.0844, 0.0677) + val interceptR = 1.9206 + val scaleR = 0.977 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. 
+ testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.761219 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.5287044 3.3285858 10.7517072 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.761219 + val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression w/o intercept") { + val trainer = new AFTSurvivalRegression().setFitIntercept(false) + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2 - 1, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + feature1 0.896 0.0685 13.1 3.93e-39 + feature2 -0.709 0.0522 -13.6 5.78e-42 + Log(scale) 0.420 0.0401 10.5 1.23e-25 + + Scale= 1.52 + + Weibull distribution + Loglik(model)= -1292.4 Loglik(intercept only)= -1072.7 + Chisq= -439.57 on 1 degrees of freedom, p= 1 + Number of Newton-Raphson Iterations: 6 + n= 1000 + */ + val coefficientsR = Vectors.dense(0.896, -0.709) + val interceptR = 0.0 + val scaleR = 1.52 + + assert(model.intercept === interceptR) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 44.54465 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 1.452103 25.506077 158.428600 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 44.54465 + val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } +} From d009da2f5c803f3b7344c96abbfcf3ecef2f5ad2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Sep 2015 22:05:20 -0700 Subject: [PATCH 022/896] [SPARK-10682] [GRAPHX] Remove Bagel test suites. Bagel has been deprecated and we haven't done any changes to it. 
There is no need to run those tests. This should speed up tests by 1 min. Author: Reynold Xin Closes #8807 from rxin/SPARK-10682. --- bagel/src/test/resources/log4j.properties | 27 ----- .../org/apache/spark/bagel/BagelSuite.scala | 113 ------------------ 2 files changed, 140 deletions(-) delete mode 100644 bagel/src/test/resources/log4j.properties delete mode 100644 bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties deleted file mode 100644 index edbecdae9209..000000000000 --- a/bagel/src/test/resources/log4j.properties +++ /dev/null @@ -1,27 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala deleted file mode 100644 index fb10d734ac74..000000000000 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.bagel - -import org.scalatest.{BeforeAndAfter, Assertions} -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel - -class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable -class TestMessage(val targetId: String) extends Message[String] with Serializable - -class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - } - - test("halting by voting") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("halting by message silence") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) - val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - val msgsOut = - msgs match { - case Some(ms) if (superstep < numSupersteps - 1) => - ms - case _ => - Array[TestMessage]() - } - (new TestVertex(self.active, self.age + 1), msgsOut) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("large number of iterations") { - // This tests whether jobs with a large number of iterations finish in a reasonable time, - // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang - failAfter(30 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } - - test("using non-default persistence level") { - failAfter(10 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 20 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } -} From 93c7650ab60a839a9cbe8b4ea1d5eda93e53ebe0 Mon Sep 17 00:00:00 2001 From: linweizhong Date: Thu, 17 Sep 2015 22:25:24 -0700 Subject: [PATCH 023/896] [SPARK-9522] [SQL] SparkSubmit process can not exit if kill application when HiveThriftServer was starting When we start 
HiveThriftServer, we will start SparkContext first, then start HiveServer2, if we kill application while HiveServer2 is starting then SparkContext will stop successfully, but SparkSubmit process can not exit. Author: linweizhong Closes #7853 from Sephiroth-Lin/SPARK-9522. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/sql/hive/thriftserver/HiveThriftServer2.scala | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9c3218719f7f..ebd8e946ee7a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -97,7 +97,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() - private val stopped: AtomicBoolean = new AtomicBoolean(false) + private[spark] val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { if (stopped.get()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index dd9fef9206d0..a0643cec0fb7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -93,6 +93,12 @@ object HiveThriftServer2 extends Logging { } else { None } + // If application was killed before HiveThriftServer2 start successfully then SparkSubmit + // process can not exit, so check whether if SparkContext was stopped. + if (SparkSQLEnv.sparkContext.stopped.get()) { + logError("SparkContext has stopped even if HiveServer2 has started, so exit") + System.exit(-1) + } } catch { case e: Exception => logError("Error starting HiveThriftServer2", e) From 9a56dcdf7f19c9f7f913a2ce9bc981cb43a113c5 Mon Sep 17 00:00:00 2001 From: Felix Bechstein Date: Thu, 17 Sep 2015 22:42:46 -0700 Subject: [PATCH 024/896] docs/running-on-mesos.md: state default values in default column This PR simply uses the default value column for defaults. Author: Felix Bechstein Closes #8810 from felixb/fix_mesos_doc. --- docs/running-on-mesos.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 247e6ecfbdb8..1814fb32ed8a 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -332,21 +332,21 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.principal - Framework principal to authenticate to Mesos + (none) Set the principal with which Spark framework will use to authenticate with Mesos. spark.mesos.secret - Framework secret to authenticate to Mesos + (none)/td> Set the secret with which Spark framework will use to authenticate with Mesos. spark.mesos.role - Role for the Spark framework + * Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations and resource weight sharing. @@ -354,7 +354,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.constraints - Attribute based constraints to be matched against when accepting resource offers. + (none) Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. 
Refer to Mesos Attributes & Resources for more information on attributes.
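
A minimal Scala sketch of how the Mesos-related properties documented above might be supplied through SparkConf; the property names come from the table, while the application name and every value shown here are purely hypothetical placeholders:

    import org.apache.spark.{SparkConf, SparkContext}

    // Hypothetical values only -- a real deployment substitutes its own role, principal,
    // secret and attribute constraints.
    val conf = new SparkConf()
      .setAppName("mesos-config-sketch")
      .set("spark.mesos.role", "analytics")            // defaults to * when unset
      .set("spark.mesos.principal", "spark-framework") // used together with spark.mesos.secret
      .set("spark.mesos.secret", "framework-secret")
      .set("spark.mesos.constraints", "os:centos7")    // accept only offers whose attributes match
    val sc = new SparkContext(conf)

The same keys can equally be passed on the spark-submit command line with --conf; setting them in code is shown here only to keep the sketch self-contained.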
    From 74d8f7dda82c3a16348f3ff22da83203e0b7f708 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Sep 2015 22:46:13 -0700 Subject: [PATCH 025/896] Added tag to documentation. --- docs/running-on-mesos.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 1814fb32ed8a..330c159c67bc 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -346,7 +346,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.role - * + * Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations and resource weight sharing. From e3b5d6cb29e0f983fcc55920619e6433298955f5 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Fri, 18 Sep 2015 00:43:02 -0700 Subject: [PATCH 026/896] [SPARK-10684] [SQL] StructType.interpretedOrdering need not to be serialized Kryo fails with buffer overflow even with max value (2G). {noformat} org.apache.spark.SparkException: Kryo serialization failed: Buffer overflow. Available: 0, required: 1 Serialization trace: containsChild (org.apache.spark.sql.catalyst.expressions.BoundReference) child (org.apache.spark.sql.catalyst.expressions.SortOrder) array (scala.collection.mutable.ArraySeq) ordering (org.apache.spark.sql.catalyst.expressions.InterpretedOrdering) interpretedOrdering (org.apache.spark.sql.types.StructType) schema (org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema). To avoid this, increase spark.kryoserializer.buffer.max value. at org.apache.spark.serializer.KryoSerializerInstance.serialize(KryoSerializer.scala:263) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:240) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) {noformat} Author: navis.ryu Closes #8808 from navis/SPARK-10684. --- .../main/scala/org/apache/spark/sql/types/StructType.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d8968ef80639..b29cf22dcb58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -305,7 +305,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru f(this) || fields.exists(field => field.dataType.existsRecursively(f)) } - private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) + @transient + private[sql] lazy val interpretedOrdering = + InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } object StructType extends AbstractDataType { From 20fd35dfd1ac402b622604e7bbedcc53a580b0a2 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Fri, 18 Sep 2015 08:22:38 -0700 Subject: [PATCH 027/896] [SPARK-10451] [SQL] Prevent unnecessary serializations in InMemoryColumnarTableScan Many of the fields in InMemoryColumnar scan and InMemoryRelation can be made transient. This reduces my 1000ms job to abt 700 ms . The task size reduces from 2.8 mb to ~1300kb Author: Yash Datta Closes #8604 from saucam/serde. 
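
The change below applies a common closure-capture pattern: mark bulky fields @transient and copy just the values a closure needs into local vals, so tasks serialize those small values rather than the whole enclosing relation or scan object. A minimal sketch of that pattern, using hypothetical classes that are not part of this patch:

    import org.apache.spark.rdd.RDD

    // Illustrative only: `Lookup` and `labelAll` do not exist in Spark. The large cache is
    // marked @transient, and the closure captures the local `labels` val instead of `lookup`,
    // so the serialized task carries only the small Seq[String].
    class Lookup(@transient val bulkyCache: Array[Byte], val labels: Seq[String])
      extends Serializable

    def labelAll(ids: RDD[Int], lookup: Lookup): RDD[String] = {
      val labels = lookup.labels              // local copy; `lookup` itself is not captured
      ids.map(i => labels(i % labels.length)) // assumes non-empty labels and non-negative ids
    }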
--- .../columnar/InMemoryColumnarTableScan.scala | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 66d429bc0619..d7e145f9c2bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -48,10 +48,10 @@ private[sql] case class InMemoryRelation( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - child: SparkPlan, + @transient child: SparkPlan, tableName: Option[String])( - private var _cachedColumnBuffers: RDD[CachedBatch] = null, - private var _statistics: Statistics = null, + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null, + @transient private var _statistics: Statistics = null, private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) extends LogicalPlan with MultiInstanceRelation { @@ -62,7 +62,7 @@ private[sql] case class InMemoryRelation( _batchStats } - val partitionStatistics = new PartitionStatistics(output) + @transient val partitionStatistics = new PartitionStatistics(output) private def computeSizeInBytes = { val sizeOfRow: Expression = @@ -196,7 +196,7 @@ private[sql] case class InMemoryRelation( private[sql] case class InMemoryColumnarTableScan( attributes: Seq[Attribute], predicates: Seq[Expression], - relation: InMemoryRelation) + @transient relation: InMemoryRelation) extends LeafNode { override def output: Seq[Attribute] = attributes @@ -205,7 +205,7 @@ private[sql] case class InMemoryColumnarTableScan( // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. - val buildFilter: PartialFunction[Expression, Expression] = { + @transient val buildFilter: PartialFunction[Expression, Expression] = { case And(lhs: Expression, rhs: Expression) if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) @@ -268,16 +268,23 @@ private[sql] case class InMemoryColumnarTableScan( readBatches.setValue(0) } - relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => - val partitionFilter = newPredicate( - partitionFilters.reduceOption(And).getOrElse(Literal(true)), - relation.partitionStatistics.schema) + // Using these variables here to avoid serialization of entire objects (if referenced directly) + // within the map Partitions closure. + val schema = relation.partitionStatistics.schema + val schemaIndex = schema.zipWithIndex + val relOutput = relation.output + val buffers = relation.cachedColumnBuffers + + buffers.mapPartitions { cachedBatchIterator => + val partitionFilter = newPredicate( + partitionFilters.reduceOption(And).getOrElse(Literal(true)), + schema) // Find the ordinals and data types of the requested columns. If none are requested, use the // narrowest (the field with minimum default element size). 
val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) { val (narrowestOrdinal, narrowestDataType) = - relation.output.zipWithIndex.map { case (a, ordinal) => + relOutput.zipWithIndex.map { case (a, ordinal) => ordinal -> a.dataType } minBy { case (_, dataType) => ColumnType(dataType).defaultSize @@ -285,7 +292,7 @@ private[sql] case class InMemoryColumnarTableScan( Seq(narrowestOrdinal) -> Seq(narrowestDataType) } else { attributes.map { a => - relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType + relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType }.unzip } @@ -296,7 +303,7 @@ private[sql] case class InMemoryColumnarTableScan( // Build column accessors val columnAccessors = requestedColumnIndices.map { batchColumnIndex => ColumnAccessor( - relation.output(batchColumnIndex).dataType, + relOutput(batchColumnIndex).dataType, ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex))) } @@ -328,7 +335,7 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { + def statsString: String = schemaIndex.map { case (a, i) => val value = cachedBatch.stats.get(i, a.dataType) s"${a.name}: $value" From 35e8ab939000d4a1a01c1af4015c25ff6f4013a3 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 18 Sep 2015 09:53:52 -0700 Subject: [PATCH 028/896] [SPARK-10615] [PYSPARK] change assertEquals to assertEqual As ```assertEquals``` is deprecated, so we need to change ```assertEquals``` to ```assertEqual``` for existing python unit tests. Author: Yanbo Liang Closes #8814 from yanboliang/spark-10615. --- python/pyspark/ml/tests.py | 16 +-- python/pyspark/mllib/tests.py | 162 +++++++++++++++--------------- python/pyspark/sql/tests.py | 18 ++-- python/pyspark/streaming/tests.py | 2 +- 4 files changed, 99 insertions(+), 99 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b892318f50bd..648fa8858fba 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -182,7 +182,7 @@ def test_params(self): self.assertEqual(testParams.getMaxIter(), 10) testParams.setMaxIter(100) self.assertTrue(testParams.isSet(maxIter)) - self.assertEquals(testParams.getMaxIter(), 100) + self.assertEqual(testParams.getMaxIter(), 100) self.assertTrue(testParams.hasParam(inputCol)) self.assertFalse(testParams.hasDefault(inputCol)) @@ -195,7 +195,7 @@ def test_params(self): testParams._setDefault(seed=41) testParams.setSeed(43) - self.assertEquals( + self.assertEqual( testParams.explainParams(), "\n".join(["inputCol: input column name (undefined)", "maxIter: max number of iterations (>= 0) (default: 10, current: 100)", @@ -264,23 +264,23 @@ def test_ngram(self): self.assertEqual(ngram0.getInputCol(), "input") self.assertEqual(ngram0.getOutputCol(), "output") transformedDF = ngram0.transform(dataset) - self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) def test_stopwordsremover(self): sqlContext = SQLContext(self.sc) dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") # Default - self.assertEquals(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getInputCol(), "input") transformedDF = stopWordRemover.transform(dataset) - 
self.assertEquals(transformedDF.head().output, ["panda"]) + self.assertEqual(transformedDF.head().output, ["panda"]) # Custom stopwords = ["panda"] stopWordRemover.setStopWords(stopwords) - self.assertEquals(stopWordRemover.getInputCol(), "input") - self.assertEquals(stopWordRemover.getStopWords(), stopwords) + self.assertEqual(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) - self.assertEquals(transformedDF.head().output, ["a"]) + self.assertEqual(transformedDF.head().output, ["a"]) class HasInducedError(Params): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 636f9a06cab7..96cf13495aa9 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -166,13 +166,13 @@ def test_dot(self): [1., 2., 3., 4.], [1., 2., 3., 4.]]) arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEquals(10.0, sv.dot(dv)) + self.assertEqual(10.0, sv.dot(dv)) self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEquals(30.0, dv.dot(dv)) + self.assertEqual(30.0, dv.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEquals(30.0, lst.dot(dv)) + self.assertEqual(30.0, lst.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEquals(7.0, sv.dot(arr)) + self.assertEqual(7.0, sv.dot(arr)) def test_squared_distance(self): sv = SparseVector(4, {1: 1, 3: 2}) @@ -181,27 +181,27 @@ def test_squared_distance(self): lst1 = [4, 3, 2, 1] arr = pyarray.array('d', [0, 2, 1, 3]) narr = array([0, 2, 1, 3]) - self.assertEquals(15.0, _squared_distance(sv, dv)) - self.assertEquals(25.0, _squared_distance(sv, lst)) - self.assertEquals(20.0, _squared_distance(dv, lst)) - self.assertEquals(15.0, _squared_distance(dv, sv)) - self.assertEquals(25.0, _squared_distance(lst, sv)) - self.assertEquals(20.0, _squared_distance(lst, dv)) - self.assertEquals(0.0, _squared_distance(sv, sv)) - self.assertEquals(0.0, _squared_distance(dv, dv)) - self.assertEquals(0.0, _squared_distance(lst, lst)) - self.assertEquals(25.0, _squared_distance(sv, lst1)) - self.assertEquals(3.0, _squared_distance(sv, arr)) - self.assertEquals(3.0, _squared_distance(sv, narr)) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) def test_hash(self): v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEquals(hash(v1), hash(v2)) - self.assertEquals(hash(v1), hash(v3)) - self.assertEquals(hash(v2), hash(v3)) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) self.assertFalse(hash(v1) == hash(v4)) self.assertFalse(hash(v2) == hash(v4)) @@ -212,8 +212,8 @@ def test_eq(self): v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) v6 = 
SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEquals(v1, v2) - self.assertEquals(v1, v3) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) self.assertFalse(v2 == v4) self.assertFalse(v1 == v5) self.assertFalse(v1 == v6) @@ -238,13 +238,13 @@ def test_conversion(self): def test_sparse_vector_indexing(self): sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv[0], 0.) - self.assertEquals(sv[3], 2.) - self.assertEquals(sv[1], 1.) - self.assertEquals(sv[2], 0.) - self.assertEquals(sv[-1], 2) - self.assertEquals(sv[-2], 0) - self.assertEquals(sv[-4], 0) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[-1], 2) + self.assertEqual(sv[-2], 0) + self.assertEqual(sv[-4], 0) for ind in [4, -5]: self.assertRaises(ValueError, sv.__getitem__, ind) for ind in [7.8, '1']: @@ -255,7 +255,7 @@ def test_matrix_indexing(self): expected = [[0, 6], [1, 8], [4, 10]] for i in range(3): for j in range(2): - self.assertEquals(mat[i, j], expected[i][j]) + self.assertEqual(mat[i, j], expected[i][j]) def test_repr_dense_matrix(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -308,11 +308,11 @@ def test_sparse_matrix(self): # Test sparse matrix creation. sm1 = SparseMatrix( 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEquals(sm1.numRows, 3) - self.assertEquals(sm1.numCols, 4) - self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) self.assertTrue( repr(sm1), 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') @@ -325,13 +325,13 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1[i, j]) + self.assertEqual(expected[i][j], sm1[i, j]) self.assertTrue(array_equal(sm1.toArray(), expected)) # Test conversion to dense and sparse. 
smnew = sm1.toDense().toSparse() - self.assertEquals(sm1.numRows, smnew.numRows) - self.assertEquals(sm1.numCols, smnew.numCols) + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) self.assertTrue(array_equal(sm1.values, smnew.values)) @@ -339,11 +339,11 @@ def test_sparse_matrix(self): sm1t = SparseMatrix( 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], isTransposed=True) - self.assertEquals(sm1t.numRows, 3) - self.assertEquals(sm1t.numCols, 4) - self.assertEquals(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEquals(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEquals(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) expected = [ [3, 2, 0, 0], @@ -352,18 +352,18 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1t[i, j]) + self.assertEqual(expected[i][j], sm1t[i, j]) self.assertTrue(array_equal(sm1t.toArray(), expected)) def test_dense_matrix_is_transposed(self): mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEquals(mat1, mat) + self.assertEqual(mat1, mat) expected = [[0, 4], [1, 6], [3, 9]] for i in range(3): for j in range(2): - self.assertEquals(mat1[i, j], expected[i][j]) + self.assertEqual(mat1[i, j], expected[i][j]) self.assertTrue(array_equal(mat1.toArray(), expected)) sm = mat1.toSparse() @@ -412,8 +412,8 @@ def test_kmeans(self): ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", initializationSteps=7, epsilon=1e-4) - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_kmeans_deterministic(self): from pyspark.mllib.clustering import KMeans @@ -443,8 +443,8 @@ def test_gmm(self): clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, maxIterations=10, seed=56) labels = clusters.predict(data).collect() - self.assertEquals(labels[0], labels[1]) - self.assertEquals(labels[2], labels[3]) + self.assertEqual(labels[0], labels[1]) + self.assertEqual(labels[2], labels[3]) def test_gmm_deterministic(self): from pyspark.mllib.clustering import GaussianMixture @@ -456,7 +456,7 @@ def test_gmm_deterministic(self): clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, maxIterations=10, seed=63) for c1, c2 in zip(clusters1.weights, clusters2.weights): - self.assertEquals(round(c1, 7), round(c2, 7)) + self.assertEqual(round(c1, 7), round(c2, 7)) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -711,18 +711,18 @@ def test_serialize(self): lil[1, 0] = 1 lil[3, 0] = 2 sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv, _convert_to_vector(lil)) - self.assertEquals(sv, _convert_to_vector(lil.tocsc())) - self.assertEquals(sv, _convert_to_vector(lil.tocoo())) - self.assertEquals(sv, _convert_to_vector(lil.tocsr())) - self.assertEquals(sv, 
_convert_to_vector(lil.todok())) + self.assertEqual(sv, _convert_to_vector(lil)) + self.assertEqual(sv, _convert_to_vector(lil.tocsc())) + self.assertEqual(sv, _convert_to_vector(lil.tocoo())) + self.assertEqual(sv, _convert_to_vector(lil.tocsr())) + self.assertEqual(sv, _convert_to_vector(lil.todok())) def serialize(l): return ser.loads(ser.dumps(_convert_to_vector(l))) - self.assertEquals(sv, serialize(lil)) - self.assertEquals(sv, serialize(lil.tocsc())) - self.assertEquals(sv, serialize(lil.tocsr())) - self.assertEquals(sv, serialize(lil.todok())) + self.assertEqual(sv, serialize(lil)) + self.assertEqual(sv, serialize(lil.tocsc())) + self.assertEqual(sv, serialize(lil.tocsr())) + self.assertEqual(sv, serialize(lil.todok())) def test_dot(self): from scipy.sparse import lil_matrix @@ -730,7 +730,7 @@ def test_dot(self): lil[1, 0] = 1 lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) - self.assertEquals(10.0, dv.dot(lil)) + self.assertEqual(10.0, dv.dot(lil)) def test_squared_distance(self): from scipy.sparse import lil_matrix @@ -739,8 +739,8 @@ def test_squared_distance(self): lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - self.assertEquals(15.0, dv.squared_distance(lil)) - self.assertEquals(15.0, sv.squared_distance(lil)) + self.assertEqual(15.0, dv.squared_distance(lil)) + self.assertEqual(15.0, sv.squared_distance(lil)) def scipy_matrix(self, size, values): """Create a column SciPy matrix from a dictionary of values""" @@ -759,8 +759,8 @@ def test_clustering(self): self.scipy_matrix(3, {2: 1.1}) ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -984,12 +984,12 @@ def test_word2vec_setters(self): .setNumIterations(10) \ .setSeed(1024) \ .setMinCount(3) - self.assertEquals(model.vectorSize, 2) + self.assertEqual(model.vectorSize, 2) self.assertTrue(model.learningRate < 0.02) - self.assertEquals(model.numPartitions, 2) - self.assertEquals(model.numIterations, 10) - self.assertEquals(model.seed, 1024) - self.assertEquals(model.minCount, 3) + self.assertEqual(model.numPartitions, 2) + self.assertEqual(model.numIterations, 10) + self.assertEqual(model.seed, 1024) + self.assertEqual(model.minCount, 3) def test_word2vec_get_vectors(self): data = [ @@ -1002,7 +1002,7 @@ def test_word2vec_get_vectors(self): ["a"] ] model = Word2Vec().fit(self.sc.parallelize(data)) - self.assertEquals(len(model.getVectors()), 3) + self.assertEqual(len(model.getVectors()), 3) class StandardScalerTests(MLlibTestCase): @@ -1044,8 +1044,8 @@ def test_model_params(self): """Test that the model params are set correctly""" stkm = StreamingKMeans() stkm.setK(5).setDecayFactor(0.0) - self.assertEquals(stkm._k, 5) - self.assertEquals(stkm._decayFactor, 0.0) + self.assertEqual(stkm._k, 5) + self.assertEqual(stkm._decayFactor, 0.0) # Model not set yet. 
self.assertIsNone(stkm.latestModel()) @@ -1053,9 +1053,9 @@ def test_model_params(self): stkm.setInitialCenters( centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) - self.assertEquals( + self.assertEqual( stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) - self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0]) + self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) def test_accuracy_for_single_center(self): """Test that parameters obtained are correct for a single center.""" @@ -1070,7 +1070,7 @@ def test_accuracy_for_single_center(self): self.ssc.start() def condition(): - self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) return True self._eventually(condition, catch_assertions=True) @@ -1114,7 +1114,7 @@ def test_trainOn_model(self): def condition(): finalModel = stkm.latestModel() self.assertTrue(all(finalModel.centers == array(initCenters))) - self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) return True self._eventually(condition, catch_assertions=True) @@ -1141,7 +1141,7 @@ def update(rdd): self.ssc.start() def condition(): - self.assertEquals(result, [[0], [1], [2], [3]]) + self.assertEqual(result, [[0], [1], [2], [3]]) return True self._eventually(condition, catch_assertions=True) @@ -1263,7 +1263,7 @@ def test_convergence(self): self.ssc.start() def condition(): - self.assertEquals(len(models), len(input_batches)) + self.assertEqual(len(models), len(input_batches)) return True # We want all batches to finish for this test. @@ -1297,7 +1297,7 @@ def test_predictions(self): self.ssc.start() def condition(): - self.assertEquals(len(true_predicted), len(input_batches)) + self.assertEqual(len(true_predicted), len(input_batches)) return True self._eventually(condition, catch_assertions=True) @@ -1400,7 +1400,7 @@ def test_parameter_convergence(self): self.ssc.start() def condition(): - self.assertEquals(len(model_weights), len(batches)) + self.assertEqual(len(model_weights), len(batches)) return True # We want all batches to finish for this test. @@ -1433,7 +1433,7 @@ def test_prediction(self): self.ssc.start() def condition(): - self.assertEquals(len(samples), len(batches)) + self.assertEqual(len(samples), len(batches)) return True # We want all batches to finish for this test. 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f2172b7a27d8..3e680f1030a7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -157,7 +157,7 @@ class DataTypeTests(unittest.TestCase): def test_data_type_eq(self): lt = LongType() lt2 = pickle.loads(pickle.dumps(LongType())) - self.assertEquals(lt, lt2) + self.assertEqual(lt, lt2) # regression test for SPARK-7978 def test_decimal_type(self): @@ -393,7 +393,7 @@ def test_infer_nested_schema(self): CustomRow(field1=2, field2="row2"), CustomRow(field1=3, field2="row3")]) df = self.sqlCtx.inferSchema(rdd) - self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] @@ -403,7 +403,7 @@ def test_create_dataframe_from_objects(self): def test_select_null_literal(self): df = self.sqlCtx.sql("select null as col") - self.assertEquals(Row(col=None), df.first()) + self.assertEqual(Row(col=None), df.first()) def test_apply_schema(self): from datetime import date, datetime @@ -519,14 +519,14 @@ def test_apply_schema_with_udt(self): StructField("point", ExamplePointUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = (1.0, PythonOnlyPoint(1.0, 2.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", PythonOnlyUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT @@ -554,14 +554,14 @@ def test_parquet_with_udt(self): df0.write.parquet(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) df0 = self.sqlCtx.createDataFrame([row]) df0.write.parquet(output_dir, mode='overwrite') df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_column_operators(self): ci = self.df.key @@ -826,8 +826,8 @@ def test_infer_long_type(self): output_dir = os.path.join(self.tempdir.name, "infer_long_type") df.saveAsParquetFile(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) - self.assertEquals('a', df1.first().f1) - self.assertEquals(100000000000000, df1.first().f2) + self.assertEqual('a', df1.first().f1) + self.assertEqual(100000000000000, df1.first().f2) self.assertEqual(_infer_type(1), LongType()) self.assertEqual(_infer_type(2**10), LongType()) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index cfea95b0dec7..e4e56fff3b3f 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -693,7 +693,7 @@ def check_output(n): # Verify that getActiveOrCreate() returns active context self.setupCalled = False - self.assertEquals(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) + self.assertEqual(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) self.assertFalse(self.setupCalled) # Verify that getActiveOrCreate() uses existing SparkContext From 00a2911c5bea67a1a4796fb1d6fd5d0a8ee79001 
Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 18 Sep 2015 12:19:08 -0700 Subject: [PATCH 029/896] [SPARK-10540] Fixes flaky all-data-type test This PR breaks the original test case into multiple ones (one test case for each data type). In this way, test failure output can be much more readable. Within each test case, we build a table with two columns, one of them is for the data type to test, the other is an "index" column, which is used to sort the DataFrame and workaround [SPARK-10591] [1] [1]: https://issues.apache.org/jira/browse/SPARK-10591 Author: Cheng Lian Closes #8768 from liancheng/spark-10540/test-all-data-types. --- .../sql/sources/hadoopFsRelationSuites.scala | 109 +++++++----------- 1 file changed, 43 insertions(+), 66 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 8ffcef85668d..d7504936d90e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -100,80 +100,57 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } - ignore("test all data types") { - withTempPath { file => - // Create the schema. - val struct = - StructType( - StructField("f1", FloatType, true) :: - StructField("f2", ArrayType(BooleanType), true) :: Nil) - // TODO: add CalendarIntervalType to here once we can save it out. - val dataTypes = - Seq( - StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), - DateType, TimestampType, - ArrayType(IntegerType), MapType(StringType, LongType), struct, - new MyDenseVectorUDT()) - val fields = dataTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, nullable = true) - } - val schema = StructType(fields) - - // Generate data at the driver side. We need to materialize the data first and then - // create RDD. - val maybeDataGenerator = - RandomDataGenerator.forType( - dataType = schema, + private val supportedDataTypes = Seq( + StringType, BinaryType, + NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + new MyDenseVectorUDT() + ).filter(supportsDataType) + + for (dataType <- supportedDataTypes) { + test(s"test all data types - $dataType") { + withTempPath { file => + val path = file.getCanonicalPath + + val dataGenerator = RandomDataGenerator.forType( + dataType = dataType, nullable = true, - seed = Some(System.nanoTime())) - val dataGenerator = - maybeDataGenerator - .getOrElse(fail(s"Failed to create data generator for schema $schema")) - val data = (1 to 10).map { i => - dataGenerator.apply() match { - case row: Row => row - case null => Row.fromSeq(Seq.fill(schema.length)(null)) - case other => - fail(s"Row or null is expected to be generated, " + - s"but a ${other.getClass.getCanonicalName} is generated.") + seed = Some(System.nanoTime()) + ).getOrElse { + fail(s"Failed to create data generator for schema $dataType") } - } - // Create a DF for the schema with random data. 
- val rdd = sqlContext.sparkContext.parallelize(data, 10) - val df = sqlContext.createDataFrame(rdd, schema) + // Create a DF for the schema with random data. The index field is used to sort the + // DataFrame. This is a workaround for SPARK-10591. + val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", dataType, nullable = true) + val rdd = sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) + val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - // All columns that have supported data types of this source. - val supportedColumns = schema.fields.collect { - case StructField(name, dataType, _, _) if supportsDataType(dataType) => name - } - val selectedColumns = util.Random.shuffle(supportedColumns.toSeq) - - val dfToBeSaved = df.selectExpr(selectedColumns: _*) - - // Save the data out. - dfToBeSaved - .write - .format(dataSourceName) - .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. - .save(file.getCanonicalPath) + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .save(path) - val loadedDF = - sqlContext + val loadedDF = sqlContext .read .format(dataSourceName) - .schema(dfToBeSaved.schema) - .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. - .load(file.getCanonicalPath) - .selectExpr(selectedColumns: _*) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .load(path) + .orderBy("index") - // Read the data back. - checkAnswer( - loadedDF, - dfToBeSaved - ) + checkAnswer(loadedDF, df) + } } } From c6f8135ee52202bd86adb090ab631e80330ea4df Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 18 Sep 2015 13:20:13 -0700 Subject: [PATCH 030/896] [SPARK-10539] [SQL] Project should not be pushed down through Intersect or Except #8742 Intersect and Except are both set operators and they use the all the columns to compare equality between rows. When pushing their Project parent down, the relations they based on would change, therefore not an equivalent transformation. JIRA: https://issues.apache.org/jira/browse/SPARK-10539 I added some comments based on the fix of https://github.com/apache/spark/pull/8742. Author: Yijie Shen Author: Yin Huai Closes #8823 from yhuai/fix_set_optimization. --- .../sql/catalyst/optimizer/Optimizer.scala | 37 ++++++++++--------- .../optimizer/SetOperationPushDownSuite.scala | 23 ++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 9 +++++ 3 files changed, 39 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 648a65e7c0eb..324f40a051c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,7 +85,22 @@ object SamplePushDown extends Rule[LogicalPlan] { } /** - * Pushes operations to either side of a Union, Intersect or Except. + * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Operations that are safe to pushdown are listed as follows. + * Union: + * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is + * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, + * we will not be able to pushdown Projections. 
+ * + * Intersect: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. + * + * Except: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. */ object SetOperationPushDown extends Rule[LogicalPlan] { @@ -122,40 +137,26 @@ object SetOperationPushDown extends Rule[LogicalPlan] { Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into union + // Push down projection through UNION ALL case Project(projectList, u @ Union(left, right)) => val rewrites = buildRewrites(u) Union( Project(projectList, left), Project(projectList.map(pushToRight(_, rewrites)), right)) - // Push down filter into intersect + // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => val rewrites = buildRewrites(i) Intersect( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into intersect - case Project(projectList, i @ Intersect(left, right)) => - val rewrites = buildRewrites(i) - Intersect( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) - - // Push down filter into except + // Push down filter through EXCEPT case Filter(condition, e @ Except(left, right)) => val rewrites = buildRewrites(e) Except( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - - // Push down projection into except - case Project(projectList, e @ Except(left, right)) => - val rewrites = buildRewrites(e) - Except( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 49c979bc7d72..3fca47a023dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -60,23 +60,22 @@ class SetOperationPushDownSuite extends PlanTest { comparePlans(exceptOptimized, exceptCorrectAnswer) } - test("union/intersect/except: project to each side") { + test("union: project to each side") { val unionQuery = testUnion.select('a) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select('a), testRelation2.select('d)).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { val intersectQuery = testIntersect.select('b, 'c) val exceptQuery = testExcept.select('a, 'b, 'c) - val unionOptimized = Optimize.execute(unionQuery.analyze) val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze - val intersectCorrectAnswer = - Intersect(testRelation.select('b, 'c), testRelation2.select('e, 'f)).analyze - val exceptCorrectAnswer = - Except(testRelation.select('a, 'b, 'c), testRelation2.select('d, 'e, 'f)).analyze - - 
comparePlans(unionOptimized, unionCorrectAnswer) - comparePlans(intersectOptimized, intersectCorrectAnswer) - comparePlans(exceptOptimized, exceptCorrectAnswer) } + comparePlans(intersectOptimized, intersectQuery.analyze) + comparePlans(exceptOptimized, exceptQuery.analyze) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c167999af580..1370713975f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -907,4 +907,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) } } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } } From 3a22b1004f527d54d399dd0225cd7f2f8ffad9c5 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 18 Sep 2015 13:47:14 -0700 Subject: [PATCH 031/896] [SPARK-10449] [SQL] Don't merge decimal types with incompatable precision or scales From JIRA: Schema merging should only handle struct fields. But currently we also reconcile decimal precision and scale information. Author: Holden Karau Closes #8634 from holdenk/SPARK-10449-dont-merge-different-precision. --- .../org/apache/spark/sql/types/StructType.scala | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b29cf22dcb58..d6b436724b2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -373,10 +373,19 @@ object StructType extends AbstractDataType { StructType(newFields) case (DecimalType.Fixed(leftPrecision, leftScale), - DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType( - max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), - max(leftScale, rightScale)) + DecimalType.Fixed(rightPrecision, rightScale)) => + if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { + DecimalType(leftPrecision, leftScale) + } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") + } else if (leftPrecision != rightPrecision) { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"precision $leftPrecision and $rightPrecision") + } else { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"scala $leftScale and $rightScale") + } case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) if leftUdt.userClass == rightUdt.userClass => leftUdt From 348d7c9a93dd00d3d1859342a8eb0aea2e77f597 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 18 Sep 2015 13:48:41 -0700 Subject: [PATCH 032/896] [SPARK-9808] Remove hash shuffle file consolidation. Author: Reynold Xin Closes #8812 from rxin/SPARK-9808-1. 
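
With consolidation removed, hash-based shuffle is back to writing one file per (map task, reducer) pair, as the resolver comment in the diff below notes. A rough, illustrative Scala sketch (not part of the patch) of what that means for file counts:

    // Each map task writes one file per reducer, so the total shuffle file count is the product.
    def hashShuffleFileCount(numMapTasks: Int, numReducers: Int): Long =
      numMapTasks.toLong * numReducers

    hashShuffleFileCount(1000, 200) // 200000 files for 1000 map tasks x 200 reducers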
--- .../shuffle/FileShuffleBlockResolver.scala | 178 ++---------------- .../apache/spark/storage/BlockManager.scala | 9 - .../org/apache/spark/storage/DiskStore.scala | 3 - .../hash/HashShuffleManagerSuite.scala | 110 ----------- docs/configuration.md | 10 - .../shuffle/ExternalShuffleBlockResolver.java | 4 - project/MimaExcludes.scala | 4 + 7 files changed, 17 insertions(+), 301 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index c057de9b3f4d..d9902f96dfd4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -17,9 +17,7 @@ package org.apache.spark.shuffle -import java.io.File import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ @@ -28,10 +26,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.FileShuffleBlockResolver.ShuffleFileGroup import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} -import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -43,24 +39,7 @@ private[spark] trait ShuffleWriterGroup { /** * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file - * per reducer (this set of files is called a ShuffleFileGroup). - * - * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle - * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer - * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle - * files, it releases them for another task. - * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: - * - shuffleId: The unique id given to the entire shuffle stage. - * - bucketId: The id of the output partition (i.e., reducer id) - * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a - * time owns a particular fileId, and this id is returned to a pool when the task finishes. - * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length) - * that specifies where in a given file the actual block data is located. - * - * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping - * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for - * each block stored in each file. In order to find the location of a shuffle block, we search the - * files within a ShuffleFileGroups associated with the block's reducer. + * per reducer. */ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getHashBasedShuffleBlockData(). 
@@ -71,26 +50,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private lazy val blockManager = SparkEnv.get.blockManager - // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. - // TODO: Remove this once the shuffle file consolidation feature is stable. - private val consolidateShuffleFiles = - conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** - * Contains all the state related to a particular shuffle. This includes a pool of unused - * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. + * Contains all the state related to a particular shuffle. */ - private class ShuffleState(val numBuckets: Int) { - val nextFileId = new AtomicInteger(0) - val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - + private class ShuffleState(val numReducers: Int) { /** * The mapIds of all map tasks completed on this Executor for this shuffle. - * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. */ val completedMapTasks = new ConcurrentLinkedQueue[Int]() } @@ -104,24 +72,16 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) * Get a ShuffleWriterGroup for the given map task, which will register it as complete * when the writers are closed successfully */ - def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, + def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) private val shuffleState = shuffleStates(shuffleId) - private var fileGroup: ShuffleFileGroup = null val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { - fileGroup = getUnusedFileGroup() - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, - writeMetrics) - } - } else { - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => + val writers: Array[DiskBlockObjectWriter] = { + Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. 
@@ -142,58 +102,14 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { - if (consolidateShuffleFiles) { - if (success) { - val offsets = writers.map(_.fileSegment().offset) - val lengths = writers.map(_.fileSegment().length) - fileGroup.recordMapOutput(mapId, offsets, lengths) - } - recycleFileGroup(fileGroup) - } else { - shuffleState.completedMapTasks.add(mapId) - } - } - - private def getUnusedFileGroup(): ShuffleFileGroup = { - val fileGroup = shuffleState.unusedFileGroups.poll() - if (fileGroup != null) fileGroup else newFileGroup() - } - - private def newFileGroup(): ShuffleFileGroup = { - val fileId = shuffleState.nextFileId.getAndIncrement() - val files = Array.tabulate[File](numBuckets) { bucketId => - val filename = physicalFileName(shuffleId, bucketId, fileId) - blockManager.diskBlockManager.getFile(filename) - } - val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) - shuffleState.allFileGroups.add(fileGroup) - fileGroup - } - - private def recycleFileGroup(group: ShuffleFileGroup) { - shuffleState.unusedFileGroups.add(group) + shuffleState.completedMapTasks.add(mapId) } } } override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - if (consolidateShuffleFiles) { - // Search all file groups associated with this shuffle. - val shuffleState = shuffleStates(blockId.shuffleId) - val iter = shuffleState.allFileGroups.iterator - while (iter.hasNext) { - val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) - if (segmentOpt.isDefined) { - val segment = segmentOpt.get - return new FileSegmentManagedBuffer( - transportConf, segment.file, segment.offset, segment.length) - } - } - throw new IllegalStateException("Failed to find shuffle block: " + blockId) - } else { - val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(transportConf, file, 0, file.length) - } + val file = blockManager.diskBlockManager.getFile(blockId) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } /** Remove all the blocks / files and metadata related to a particular shuffle. 
*/ @@ -209,17 +125,9 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups.asScala; - file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks.asScala; - reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() - } + for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() } logInfo("Deleted all files for shuffle " + shuffleId) true @@ -229,10 +137,6 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { - "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) - } - private def cleanup(cleanupTime: Long) { shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) } @@ -241,59 +145,3 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) metadataCleaner.cancel() } } - -private[spark] object FileShuffleBlockResolver { - /** - * A group of shuffle files, one per reducer. - * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. - */ - private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { - private var numBlocks: Int = 0 - - /** - * Stores the absolute index of each mapId in the files of this group. For instance, - * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. - */ - private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() - - /** - * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by - * position in the file. - * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every - * reducer. - */ - private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - - def apply(bucketId: Int): File = files(bucketId) - - def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { - assert(offsets.length == lengths.length) - mapIdToIndex(mapId) = numBlocks - numBlocks += 1 - for (i <- 0 until offsets.length) { - blockOffsetsByReducer(i) += offsets(i) - blockLengthsByReducer(i) += lengths(i) - } - } - - /** Returns the FileSegment associated with the given map task, or None if no entry exists. 
*/ - def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { - val file = files(reducerId) - val blockOffsets = blockOffsetsByReducer(reducerId) - val blockLengths = blockLengthsByReducer(reducerId) - val index = mapIdToIndex.getOrElse(mapId, -1) - if (index >= 0) { - val offset = blockOffsets(index) - val length = blockLengths(index) - Some(new FileSegment(file, offset, length)) - } else { - None - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d31aa68eb695..bca3942f8c55 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -106,15 +106,6 @@ private[spark] class BlockManager( } } - // Check that we're not using external shuffle service with consolidated shuffle files. - if (externalShuffleServiceEnabled - && conf.getBoolean("spark.shuffle.consolidateFiles", false) - && shuffleManager.isInstanceOf[HashShuffleManager]) { - throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated" - + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or " - + " switch to sort-based shuffle.") - } - var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1f4595628216..feb9533604ff 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -154,9 +154,6 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc override def remove(blockId: BlockId): Boolean = { val file = diskManager.getFile(blockId.name) - // If consolidation mode is used With HashShuffleMananger, the physical filename for the block - // is different from blockId.name. So the file returns here will not be exist, thus we avoid to - // delete the whole consolidated file by mistake. if (file.exists()) { file.delete() } else { diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala deleted file mode 100644 index 491dc3659e18..000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.hash - -import java.io.{File, FileWriter} - -import scala.language.reflectiveCalls - -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.FileShuffleBlockResolver -import org.apache.spark.storage.{ShuffleBlockId, FileSegment} - -class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { - private val testConf = new SparkConf(false) - - private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { - assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) - val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] - assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath) - assert(expected.offset === segment.getOffset) - assert(expected.length === segment.getLength) - } - - test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { - - val conf = new SparkConf(false) - // reset after EACH object write. This is to ensure that there are bytes appended after - // an object is written. So if the codepaths assume writeObject is end of data, this should - // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc. - conf.set("spark.serializer.objectStreamReset", "1") - conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") - - sc = new SparkContext("local", "test", conf) - - val shuffleBlockResolver = - SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[FileShuffleBlockResolver] - - val shuffle1 = shuffleBlockResolver.forMapTask(1, 1, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - for (writer <- shuffle1.writers) { - writer.write("test1", "value") - writer.write("test2", "value") - } - for (writer <- shuffle1.writers) { - writer.commitAndClose() - } - - val shuffle1Segment = shuffle1.writers(0).fileSegment() - shuffle1.releaseWriters(success = true) - - val shuffle2 = shuffleBlockResolver.forMapTask(1, 2, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - - for (writer <- shuffle2.writers) { - writer.write("test3", "value") - writer.write("test4", "vlue") - } - for (writer <- shuffle2.writers) { - writer.commitAndClose() - } - val shuffle2Segment = shuffle2.writers(0).fileSegment() - shuffle2.releaseWriters(success = true) - - // Now comes the test : - // Write to shuffle 3; and close it, but before registering it, check if the file lengths for - // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length - // of block based on remaining data in file : which could mess things up when there is - // concurrent read and writes happening to the same shuffle group. - - val shuffle3 = shuffleBlockResolver.forMapTask(1, 3, 1, new JavaSerializer(testConf), - new ShuffleWriteMetrics) - for (writer <- shuffle3.writers) { - writer.write("test3", "value") - writer.write("test4", "value") - } - for (writer <- shuffle3.writers) { - writer.commitAndClose() - } - // check before we register. 
- checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffle3.releaseWriters(success = true) - checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffleBlockResolver.removeShuffle(1) - } - - def writeToFile(file: File, numBytes: Int) { - val writer = new FileWriter(file, true) - for (i <- 0 until numBytes) writer.write(i) - writer.close() - } -} diff --git a/docs/configuration.md b/docs/configuration.md index 1a701f18881f..3700051efb44 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -390,16 +390,6 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec. - - spark.shuffle.consolidateFiles - false - - If set to "true", consolidates intermediate files created during a shuffle. Creating fewer - files can improve filesystem performance for shuffles with large numbers of reduce tasks. It - is recommended to set this to "true" when using ext4 or xfs filesystems. On ext3, this option - might degrade performance on machines with many (>8) cores due to filesystem limitations. - - spark.shuffle.file.buffer 32k diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 79beec4429a9..c5f93bb47f55 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -50,9 +50,6 @@ * of Executors. Each Executor must register its own configuration about where it stores its files * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated * from Spark's FileShuffleBlockResolver and IndexShuffleBlockResolver. - * - * Executors with shuffle file consolidation are not currently supported, as the index is stored in - * the Executor's memory, unlike the IndexShuffleBlockResolver. */ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); @@ -254,7 +251,6 @@ private void deleteExecutorDirs(String[] dirs) { * Hash-based shuffle data is simply stored as one file per block. * This logic is from FileShuffleBlockResolver. 
*/ - // TODO: Support consolidated hash shuffle files private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1c96b0958586..814a11e588ce 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,6 +70,10 @@ object MimaExcludes { "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") + ) ++ + Seq( + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") ) case v if v.startsWith("1.5") => Seq( From 8074208fa47fa654c1055c48cfa0d923edeeb04f Mon Sep 17 00:00:00 2001 From: Mingyu Kim Date: Fri, 18 Sep 2015 15:40:58 -0700 Subject: [PATCH 033/896] [SPARK-10611] Clone Configuration for each task for NewHadoopRDD This patch attempts to fix the Hadoop Configuration thread safety issue for NewHadoopRDD in the same way SPARK-2546 fixed the issue for HadoopRDD. Author: Mingyu Kim Closes #8763 from mingyukim/mkim/SPARK-10611. --- .../org/apache/spark/rdd/BinaryFileRDD.scala | 5 ++- .../org/apache/spark/rdd/NewHadoopRDD.scala | 37 ++++++++++++++++--- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 6fec00dcd0d8..aedced7408cd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -34,12 +34,13 @@ private[spark] class BinaryFileRDD[T]( override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => - configurable.setConf(getConf) + configurable.setConf(conf) case _ => } - val jobContext = newJobContext(getConf, jobId) + val jobContext = newJobContext(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 174979aaeb23..2872b93b8730 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -44,7 +44,6 @@ private[spark] class NewHadoopPartition( extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) - override def hashCode(): Int = 41 * (41 + rddId) + index } @@ -84,6 +83,27 @@ class NewHadoopRDD[K, V]( @transient protected val jobId = new JobID(jobTrackerId, id) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) + + def getConf: Configuration = { + val conf: Configuration = confBroadcast.value.value + if (shouldCloneJobConf) { + // Hadoop Configuration objects are not thread-safe, which may lead to various problems if + // one job modifies a configuration while another reads it (SPARK-2546, SPARK-10611). This + // problem occurs somewhat rarely because most jobs treat the configuration as though it's + // immutable. 
One solution, implemented here, is to clone the Configuration object. + // Unfortunately, this clone can be very expensive. To avoid unexpected performance + // regressions for workloads and Hadoop versions that do not suffer from these thread-safety + // issues, this cloning is disabled by default. + NewHadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Cloning Hadoop Configuration") + new Configuration(conf) + } + } else { + conf + } + } + override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance inputFormat match { @@ -104,7 +124,7 @@ class NewHadoopRDD[K, V]( val iter = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value + val conf = getConf val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) @@ -230,11 +250,15 @@ class NewHadoopRDD[K, V]( super.persist(storageLevel) } - - def getConf: Configuration = confBroadcast.value.value } private[spark] object NewHadoopRDD { + /** + * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456). + * Therefore, we synchronize on this lock before calling new Configuration(). + */ + val CONFIGURATION_INSTANTIATION_LOCK = new Object() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. @@ -268,12 +292,13 @@ private[spark] class WholeTextFileRDD( override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => - configurable.setConf(getConf) + configurable.setConf(conf) case _ => } - val jobContext = newJobContext(getConf, jobId) + val jobContext = newJobContext(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) From c8149ef2c57f5c47ab97ee8d8d58a216d4bc4156 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 18 Sep 2015 16:23:05 -0700 Subject: [PATCH 034/896] [MINOR] [ML] override toString of AttributeGroup This makes equality test failures much more readable. mengxr Author: Eric Liang Author: Eric Liang Closes #8826 from ericl/attrgroupstr. --- .../scala/org/apache/spark/ml/attribute/AttributeGroup.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index 457c15830fd3..2c29eeb01a92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -183,6 +183,8 @@ class AttributeGroup private ( sum = 37 * sum + attributes.map(_.toSeq).hashCode sum } + + override def toString: String = toMetadata.toString } /** From 22be2ae147a111e88896f6fb42ed46bbf108a99b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 18 Sep 2015 18:42:20 -0700 Subject: [PATCH 035/896] [SPARK-10623] [SQL] Fixes ORC predicate push-down When pushing down a leaf predicate, ORC `SearchArgument` builder requires an extra "parent" predicate (any one among `AND`/`OR`/`NOT`) to wrap the leaf predicate. E.g., to push down `a < 1`, we must build `AND(a < 1)` instead. 
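To make the requirement concrete, here is a minimal sketch (not part of the patch itself; it uses only the builder calls that appear in the diff below, and the Hive 1.2.1 `org.apache.hadoop.hive.ql.io.sarg` package path, the column name `a`, and the boxed literal are assumptions for illustration):

```
import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory

// Sketch: the leaf predicate `a < 1` is wrapped in a parent AND before build().
// Calling lessThan() directly on a fresh builder is roughly what the old code did,
// and that bare-leaf form is what the builder does not accept.
val sarg = SearchArgumentFactory.newBuilder()
  .startAnd()                          // extra "parent" predicate
  .lessThan("a", Integer.valueOf(1))   // the actual leaf predicate
  .end()
  .build()
```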
Fortunately, when actually constructing the `SearchArgument`, the builder will eliminate all those unnecessary wrappers. This PR is based on #8783 authored by zhzhan. I also took the chance to simplify `OrcFilters` a little bit to improve readability. Author: Cheng Lian Closes #8799 from liancheng/spark-10623/fix-orc-ppd. --- .../spark/sql/hive/orc/OrcFilters.scala | 56 ++++++++----------- .../spark/sql/hive/orc/OrcQuerySuite.scala | 30 ++++++++++ 2 files changed, 52 insertions(+), 34 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index b3d9f7f71a27..27193f54d3a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -31,11 +31,13 @@ import org.apache.spark.sql.sources._ * and cannot be used anymore. */ private[orc] object OrcFilters extends Logging { - def createFilter(expr: Array[Filter]): Option[SearchArgument] = { - expr.reduceOption(And).flatMap { conjunction => - val builder = SearchArgumentFactory.newBuilder() - buildSearchArgument(conjunction, builder).map(_.build()) - } + def createFilter(filters: Array[Filter]): Option[SearchArgument] = { + for { + // Combines all filters with `And`s to produce a single conjunction predicate + conjunction <- filters.reduceOption(And) + // Then tries to build a single ORC `SearchArgument` for the conjunction predicate + builder <- buildSearchArgument(conjunction, SearchArgumentFactory.newBuilder()) + } yield builder.build() } private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { @@ -102,46 +104,32 @@ private[orc] object OrcFilters extends Logging { negate <- buildSearchArgument(child, builder.startNot()) } yield negate.end() - case EqualTo(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.equals(attribute, _)) + case EqualTo(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().equals(attribute, value).end()) - case EqualNullSafe(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.nullSafeEquals(attribute, _)) + case EqualNullSafe(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().nullSafeEquals(attribute, value).end()) - case LessThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThan(attribute, _)) + case LessThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThan(attribute, value).end()) - case LessThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThanEquals(attribute, _)) + case LessThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThanEquals(attribute, value).end()) - case GreaterThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThanEquals(attribute, _).end()) + case GreaterThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThanEquals(attribute, value).end()) - case GreaterThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThan(attribute, _).end()) + case GreaterThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThan(attribute, value).end()) case IsNull(attribute) => - Some(builder.isNull(attribute)) +
Some(builder.startAnd().isNull(attribute).end()) case IsNotNull(attribute) => Some(builder.startNot().isNull(attribute).end()) - case In(attribute, values) => - Option(values) - .filter(_.forall(isSearchableLiteral)) - .map(builder.in(attribute, _)) + case In(attribute, values) if values.forall(isSearchableLiteral) => + Some(builder.startAnd().in(attribute, values.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None } }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 8bc33fcf5d90..5eb39b112970 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -344,4 +344,34 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } } + + test("SPARK-10623 Enable ORC PPD") { + withTempPath { dir => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + import testImplicits._ + + val path = dir.getCanonicalPath + sqlContext.range(10).coalesce(1).write.orc(path) + val df = sqlContext.read.orc(path) + + def checkPredicate(pred: Column, answer: Seq[Long]): Unit = { + checkAnswer(df.where(pred), answer.map(Row(_))) + } + + checkPredicate('id === 5, Seq(5L)) + checkPredicate('id <=> 5, Seq(5L)) + checkPredicate('id < 5, 0L to 4L) + checkPredicate('id <= 5, 0L to 5L) + checkPredicate('id > 5, 6L to 9L) + checkPredicate('id >= 5, 5L to 9L) + checkPredicate('id.isNull, Seq.empty[Long]) + checkPredicate('id.isNotNull, 0L to 9L) + checkPredicate('id.isin(1L, 3L, 5L), Seq(1L, 3L, 5L)) + checkPredicate('id > 0 && 'id < 3, 1L to 2L) + checkPredicate('id < 1 || 'id > 8, Seq(0L, 9L)) + checkPredicate(!('id > 3), 0L to 3L) + checkPredicate(!('id > 0 && 'id < 3), Seq(0L) ++ (3L to 9L)) + } + } + } }
From 7ff8d68cc19299e16dedfd819b9e96480fa6cf44 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 18 Sep 2015 23:58:25 -0700 Subject: [PATCH 036/896] [SPARK-10474] [SQL] Aggregation fails to allocate memory for pointer array When `TungstenAggregation` hits memory pressure, it switches from hash-based to sort-based aggregation in-place. However, in the process we try to allocate the pointer array for writing to the new `UnsafeExternalSorter` *before* actually freeing the memory from the hash map. This led to the following exception: ``` java.io.IOException: Could not acquire 65536 bytes of memory at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.initializeForWriting(UnsafeExternalSorter.java:169) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spill(UnsafeExternalSorter.java:220) at org.apache.spark.sql.execution.UnsafeKVExternalSorter.(UnsafeKVExternalSorter.java:126) at org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter(UnsafeFixedWidthAggregationMap.java:257) at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.switchToSortBasedAggregation(TungstenAggregationIterator.scala:435) ``` Author: Andrew Or Closes #8827 from andrewor14/allocate-pointer-array.
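The gist of the change, distilled into a short sketch (simplified and not a compilable excerpt; `sorter` and `map` stand for the external sorter and the aggregation hash map handed off inside `UnsafeKVExternalSorter`, using the method names that appear in the diff below):

```
// Old ordering (simplified): spill() re-initialized for writing right away, so the
// new pointer array was requested while the hash map still held all of its pages.
//   sorter.spill()
//   map.free()

// New ordering (simplified): spill without re-initializing, release the map's
// pages, and only then allocate the pointer array for the new sorter.
sorter.spill(false /* initialize for writing */)
map.free()
sorter.initializeForWriting()
```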
--- .../unsafe/sort/UnsafeExternalSorter.java | 14 +++++- .../sql/execution/UnsafeKVExternalSorter.java | 8 ++- .../UnsafeFixedWidthAggregationMapSuite.scala | 49 ++++++++++++++++++- 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index fc364e0a895b..14b6aafdea7d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -159,7 +159,7 @@ public BoxedUnit apply() { /** * Allocates new sort data structures. Called when creating the sorter and after each spill. */ - private void initializeForWriting() throws IOException { + public void initializeForWriting() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); final long pointerArrayMemory = UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize); @@ -187,6 +187,14 @@ public void closeCurrentPage() { * Sort and spill the current records in response to memory pressure. */ public void spill() throws IOException { + spill(true); + } + + /** + * Sort and spill the current records in response to memory pressure. + * @param shouldInitializeForWriting whether to allocate memory for writing after the spill + */ + public void spill(boolean shouldInitializeForWriting) throws IOException { assert(inMemSorter != null); logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), @@ -217,7 +225,9 @@ public void spill() throws IOException { // written to disk. This also counts the space needed to store the sorter's pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - initializeForWriting(); + if (shouldInitializeForWriting) { + initializeForWriting(); + } } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 7db6b7ff50f2..b81f67a16b81 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -85,6 +85,7 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, // We will use the number of elements in the map as the initialSize of the // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize, // we will use 1 as its initial size if the map is empty. + // TODO: track pointer array memory used by this in-memory sorter! final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); @@ -123,8 +124,13 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, pageSizeBytes, inMemSorter); - sorter.spill(); + // Note: This spill doesn't actually release any memory, so if we try to allocate a new + // pointer array immediately after the spill then we may fail to acquire sufficient space + // for it (SPARK-10474). For this reason, we must initialize for writing explicitly *after* + // we have actually freed memory from our map. 
+ sorter.spill(false /* initialize for writing */); map.free(); + sorter.initializeForWriting(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index d1f0b2b1fc52..ada4d42f991c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,9 +23,10 @@ import scala.util.{Try, Random} import org.scalatest.Matchers -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -325,7 +326,7 @@ class UnsafeFixedWidthAggregationMapSuite // At here, we also test if copy is correct. iter.getKey.copy() iter.getValue.copy() - count += 1; + count += 1 } // 1 record was from the map and 4096 records were explicitly inserted. @@ -333,4 +334,48 @@ class UnsafeFixedWidthAggregationMapSuite map.free() } + + testWithMemoryLeakDetection("convert to external sorter under memory pressure (SPARK-10474)") { + val smm = ShuffleMemoryManager.createForTesting(65536) + val pageSize = 4096 + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + smm, + 128, // initial capacity + pageSize, + false // disable perf metrics + ) + + // Insert into the map until we've run out of space + val rand = new Random(42) + var hasSpace = true + while (hasSpace) { + val str = rand.nextString(1024) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + if (buf == null) { + hasSpace = false + } else { + buf.setInt(0, str.length) + } + } + + // Ensure we're actually maxed out by asserting that we can't acquire even just 1 byte + assert(smm.tryToAcquire(1) === 0) + + // Convert the map into a sorter. This used to fail before the fix for SPARK-10474 + // because we would try to acquire space for the in-memory sorter pointer array before + // actually releasing the pages despite having spilled all of them. + var sorter: UnsafeKVExternalSorter = null + try { + sorter = map.destructAndCreateExternalSorter() + } finally { + if (sorter != null) { + sorter.cleanupResources() + } + } + } + } From d507f9c0b7f7a524137a694ed6443747aaf90463 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sat, 19 Sep 2015 01:59:36 -0700 Subject: [PATCH 037/896] [SPARK-10584] [SQL] [DOC] Documentation about the compatible Hive version is wrong. In Spark 1.5.0, Spark SQL is compatible with Hive 0.12.0 through 1.2.1 but the documentation is wrong. /CC yhuai Author: Kousuke Saruta Closes #8776 from sarutak/SPARK-10584-2. --- docs/sql-programming-guide.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a0b911d20724..82d4243cc6b2 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1954,7 +1954,7 @@ without the need to write any code. 
## Running the Thrift JDBC/ODBC server The Thrift JDBC/ODBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) -in Hive 0.13. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.13. +in Hive 1.2.1. You can test the JDBC server with the beeline script that comes with either Spark or Hive 1.2.1. To start the JDBC/ODBC server, run the following in the Spark directory: @@ -2260,8 +2260,10 @@ Several caching related features are not supported yet: ## Compatibility with Apache Hive -Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Spark -SQL is based on Hive 0.12.0 and 0.13.1. +Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. +Currently Hive SerDes and UDFs are based on Hive 1.2.1, +and Spark SQL can be connected to different versions of Hive Metastore +(from 0.12.0 to 1.2.1. Also see http://spark.apache.org/docs/latest/sql-programming-guide.html#interacting-with-different-versions-of-hive-metastore). #### Deploying in Existing Hive Warehouses
From d83b6aae8b4357c56779cc98804eb350ab8af62d Mon Sep 17 00:00:00 2001 From: Alexis Seigneurin Date: Sat, 19 Sep 2015 12:01:22 +0100 Subject: [PATCH 038/896] Fixed links to the API Submitting this change on the master branch as requested in https://github.com/apache/spark/pull/8819#issuecomment-141505941 Author: Alexis Seigneurin Closes #8838 from aseigneurin/patch-2. --- docs/ml-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c5d7f990021f..0427ac6695aa 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -619,13 +619,13 @@ for row in selected.collect(): An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. `Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. -Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator). +Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). `CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. `CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
-The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator) -for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` method in each of these evaluators. From e789000b88a6bd840f821c53f42c08b97dc02496 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 19 Sep 2015 18:22:43 -0700 Subject: [PATCH 039/896] [SPARK-10155] [SQL] Change SqlParser to object to avoid memory leak Since `scala.util.parsing.combinator.Parsers` is thread-safe since Scala 2.10 (See [SI-4929](https://issues.scala-lang.org/browse/SI-4929)), we can change SqlParser to object to avoid memory leak. I didn't change other subclasses of `scala.util.parsing.combinator.Parsers` because there is only one instance in one SQLContext, which should not be an issue. Author: zsxwing Closes #8357 from zsxwing/sql-memory-leak. --- .../apache/spark/sql/catalyst/AbstractSparkSQLParser.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/ParserDialect.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 6 +++--- .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- .../src/main/scala/org/apache/spark/sql/SQLContext.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/functions.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 6 +++--- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 4 ++-- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 5898a5f93f38..2bac08eac4fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def parse(input: String): LogicalPlan = { + def parse(input: String): LogicalPlan = synchronized { // Initialize the Keywords. 
initLexical phrase(start)(new lexical.Scanner(input)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala index 554fb4eb25eb..e21d3c05464b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala @@ -61,7 +61,7 @@ abstract class ParserDialect { */ private[spark] class DefaultParserDialect extends ParserDialect { @transient - protected val sqlParser = new SqlParser + protected val sqlParser = SqlParser override def parse(sqlText: String): LogicalPlan = { sqlParser.parse(sqlText) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index f2498861c957..dfab2398857e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -37,9 +37,9 @@ import org.apache.spark.unsafe.types.CalendarInterval * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends AbstractSparkSQLParser with DataTypeParser { +object SqlParser extends AbstractSparkSQLParser with DataTypeParser { - def parseExpression(input: String): Expression = { + def parseExpression(input: String): Expression = synchronized { // Initialize the Keywords. initLexical phrase(projection)(new lexical.Scanner(input)) match { @@ -48,7 +48,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } - def parseTableIdentifier(input: String): TableIdentifier = { + def parseTableIdentifier(input: String): TableIdentifier = synchronized { // Initialize the Keywords. 
initLexical phrase(tableIdentifier)(new lexical.Scanner(input)) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3e61123c145c..8f737c202393 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -720,7 +720,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(new SqlParser().parseExpression(expr)) + Column(SqlParser.parseExpression(expr)) }: _*) } @@ -745,7 +745,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def filter(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** @@ -769,7 +769,7 @@ class DataFrame private[sql]( * @since 1.5.0 */ def where(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 745bb4ec9cf1..03e973666e88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -163,7 +163,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(new SqlParser().parseTableIdentifier(tableName)) + insertInto(SqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { @@ -197,7 +197,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(new SqlParser().parseTableIdentifier(tableName)) + saveAsTable(SqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e3fdd782e6ff..f099940800cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -590,7 +590,7 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: Map[String, String]): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -636,7 +636,7 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -732,7 +732,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ def table(tableName: String): DataFrame = { - table(new SqlParser().parseTableIdentifier(tableName)) + table(SqlParser.parseTableIdentifier(tableName)) } private def table(tableIdent: TableIdentifier): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 
60d9c509104d..2467b4e48415 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -823,7 +823,7 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = Column(new SqlParser().parseExpression(expr)) + def expr(expr: String): Column = Column(SqlParser.parseExpression(expr)) ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index d37ba5ddc2d8..c12a73486332 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -291,12 +291,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) catalog.invalidateTable(tableIdent) } @@ -311,7 +311,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { */ @Experimental def analyze(tableName: String) { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) relation match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0a5569b0a444..0c1b41e3377e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -199,7 +199,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive options: Map[String, String], isExternal: Boolean): Unit = { createDataSourceTable( - new SqlParser().parseTableIdentifier(tableName), + SqlParser.parseTableIdentifier(tableName), userSpecifiedSchema, partitionColumns, provider, @@ -375,7 +375,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } def hiveDefaultTableFilePath(tableName: String): String = { - hiveDefaultTableFilePath(new SqlParser().parseTableIdentifier(tableName)) + hiveDefaultTableFilePath(SqlParser.parseTableIdentifier(tableName)) } def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { From 2117eea71ece825fbc3797c8b38184ae221f5223 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 19 Sep 2015 21:40:21 -0700 Subject: [PATCH 040/896] [SPARK-10710] Remove ability to disable spilling in core and SQL It does not make much sense to set `spark.shuffle.spill` or `spark.sql.planner.externalSort` to false: I believe that these configurations were initially added as "escape hatches" to guard against bugs in the external operators, but these operators are now mature and well-tested. In addition, these configurations are not handled in a consistent way anymore: SQL's Tungsten codepath ignores these configurations and will continue to use spilling operators. 
Similarly, Spark Core's `tungsten-sort` shuffle manager does not respect `spark.shuffle.spill=false`. This pull request removes these configurations, adds warnings at the appropriate places, and deletes a large amount of code which was only used in code paths that did not support spilling. Author: Josh Rosen Closes #8831 from JoshRosen/remove-ability-to-disable-spilling. --- .../scala/org/apache/spark/Aggregator.scala | 59 +++++-------------- .../org/apache/spark/rdd/CoGroupedRDD.scala | 40 ++++--------- .../shuffle/hash/HashShuffleManager.scala | 8 ++- .../shuffle/sort/SortShuffleManager.scala | 10 +++- .../util/collection/ExternalSorter.scala | 6 -- .../spark/deploy/SparkSubmitSuite.scala | 22 +++---- docs/configuration.md | 14 +---- docs/sql-programming-guide.md | 7 --- python/pyspark/rdd.py | 25 +++----- python/pyspark/shuffle.py | 30 ---------- python/pyspark/tests.py | 13 +--- .../scala/org/apache/spark/sql/SQLConf.scala | 8 +-- .../spark/sql/execution/SparkStrategies.scala | 2 - .../apache/spark/sql/execution/commands.scala | 9 +++ .../org/apache/spark/sql/execution/sort.scala | 30 +--------- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 ++------ .../execution/RowFormatConvertersSuite.scala | 2 +- .../spark/sql/execution/SortSuite.scala | 4 +- 18 files changed, 81 insertions(+), 234 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 289aab9bd9e5..7196e57d5d2e 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap} +import org.apache.spark.util.collection.ExternalAppendOnlyMap /** * :: DeveloperApi :: @@ -34,59 +34,30 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. 
- private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) - @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] = combineValuesByKey(iter, null) - def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], - context: TaskContext): Iterator[(K, C)] = { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kv: Product2[K, V] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) - } - while (iter.hasNext) { - kv = iter.next() - combiners.changeValue(kv._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineValuesByKey( + iter: Iterator[_ <: Product2[K, V]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = combineCombinersByKey(iter, null) - def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext) - : Iterator[(K, C)] = - { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kc: Product2[K, C] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 - } - while (iter.hasNext) { - kc = iter.next() - combiners.changeValue(kc._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineCombinersByKey( + iter: Iterator[_ <: Product2[K, C]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } /** Update task metrics after populating the external map. 
*/ diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 7bad749d5832..935c3babd8ea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -26,7 +26,7 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} +import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer @@ -128,8 +128,6 @@ class CoGroupedRDD[K: ClassTag]( override val partitioner: Some[Partitioner] = Some(part) override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = { - val sparkConf = SparkEnv.get.conf - val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] val numRdds = dependencies.length @@ -150,34 +148,16 @@ class CoGroupedRDD[K: ClassTag]( rddIterators += ((it, depNum)) } - if (!externalSorting) { - val map = new AppendOnlyMap[K, CoGroupCombiner] - val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup) - } - val getCombiner: K => CoGroupCombiner = key => { - map.changeValue(key, update) - } - rddIterators.foreach { case (it, depNum) => - while (it.hasNext) { - val kv = it.next() - getCombiner(kv._1)(depNum) += kv._2 - } - } - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) - } else { - val map = createExternalMap(numRdds) - for ((it, depNum) <- rddIterators) { - map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) - } - context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) + val map = createExternalMap(numRdds) + for ((it, depNum) <- rddIterators) { + map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + new InterruptibleIterator(context, + map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } private def createExternalMap(numRdds: Int) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index c089088f409d..0b46634b8b46 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -24,7 +24,13 @@ import org.apache.spark.shuffle._ * A ShuffleManager using hashing, that creates one output file per reduce partition on each * mapper (possibly reusing these across waves of tasks). 
*/ -private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d7fab351ca3b..476cc1f303da 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,11 +19,17 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency} +import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ import org.apache.spark.shuffle.hash.HashShuffleReader -private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf) private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 31230d5978b2..2a30f751ff03 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -116,8 +116,6 @@ private[spark] class ExternalSorter[K, V, C]( private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() - private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 @@ -229,10 +227,6 @@ private[spark] class ExternalSorter[K, V, C]( * @param usingMap whether we're using a map or buffer as our current in-memory collection */ private def maybeSpillCollection(usingMap: Boolean): Unit = { - if (!spillingEnabled) { - return - } - var estimatedSize = 0L if (usingMap) { estimatedSize = map.estimateSize() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 1110ca6051a4..1fd470cd3b01 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -147,7 +147,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -166,7 +166,7 @@ class SparkSubmitSuite mainClass 
should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) sysProps("spark.app.name") should be ("beauty") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") sysProps("SPARK_SUBMIT") should be ("true") sysProps.keys should not contain ("spark.jars") } @@ -185,7 +185,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -206,7 +206,7 @@ class SparkSubmitSuite sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") sysProps("SPARK_SUBMIT") should be ("true") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone cluster mode") { @@ -229,7 +229,7 @@ class SparkSubmitSuite "--supervise", "--driver-memory", "4g", "--driver-cores", "5", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -253,9 +253,9 @@ class SparkSubmitSuite sysProps.keys should contain ("spark.driver.memory") sysProps.keys should contain ("spark.driver.cores") sysProps.keys should contain ("spark.driver.supervise") - sysProps.keys should contain ("spark.shuffle.spill") + sysProps.keys should contain ("spark.ui.enabled") sysProps.keys should contain ("spark.submit.deployMode") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone client mode") { @@ -266,7 +266,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -277,7 +277,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles mesos client mode") { @@ -288,7 +288,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -299,7 +299,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles confs with flag equivalents") { diff --git a/docs/configuration.md b/docs/configuration.md index 3700051efb44..5ec097c78aa3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -69,7 +69,7 @@ val sc = new SparkContext(new SparkConf()) Then, you can supply configuration values at runtime: {% highlight bash %} -./bin/spark-submit --name "My app" --master local[4] --conf spark.shuffle.spill=false +./bin/spark-submit --name "My app" --master local[4] --conf spark.eventLog.enabled=false --conf 
"spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} @@ -449,8 +449,8 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.memoryFraction 0.2 - Fraction of Java heap to use for aggregation and cogroups during shuffles, if - spark.shuffle.spill is true. At any given time, the collective size of + Fraction of Java heap to use for aggregation and cogroups during shuffles. + At any given time, the collective size of all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will begin to spill to disk. If spills are often, consider increasing this value at the expense of spark.storage.memoryFraction. @@ -483,14 +483,6 @@ Apart from these, the following properties are also available, and may be useful map-side aggregation and there are at most this many reduce partitions. - - spark.shuffle.spill - true - - If set to "true", limits the amount of memory used during reduces by spilling data out to disk. - This spilling threshold is specified by spark.shuffle.memoryFraction. - - spark.shuffle.spill.compress true diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 82d4243cc6b2..7ae9244c271e 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1936,13 +1936,6 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. - - spark.sql.planner.externalSort - true - - When true, performs sorts spilling to disk as needed otherwise sort each partition in memory. - - # Distributed SQL Engine diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ab5aab1e115f..73d7d9a5692a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -48,7 +48,7 @@ from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ +from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync @@ -580,12 +580,11 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true") memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending))) return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) @@ -610,12 +609,11 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))) if numPartitions == 1: @@ -1770,13 
+1768,11 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = self._can_spill() memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -1784,8 +1780,7 @@ def combineLocally(iterator): shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): - merger = ExternalMerger(agg, memory, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory, serializer) merger.mergeCombiners(iterator) return merger.items() @@ -1824,9 +1819,6 @@ def createZero(): return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) - def _can_spill(self): - return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" - def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) @@ -1857,14 +1849,12 @@ def mergeCombiners(a, b): a.extend(b) return a - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combine(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -1872,8 +1862,7 @@ def combine(iterator): shuffled = locally_combined.partitionBy(numPartitions) def groupByKey(it): - merger = ExternalGroupBy(agg, memory, serializer)\ - if spill else InMemoryMerger(agg) + merger = ExternalGroupBy(agg, memory, serializer) merger.mergeCombiners(it) return merger.items() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index b8118bdb7ca7..e974cda9fc3e 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -131,36 +131,6 @@ def items(self): raise NotImplementedError -class InMemoryMerger(Merger): - - """ - In memory merger based on in-memory dict. 
- """ - - def __init__(self, aggregator): - Merger.__init__(self, aggregator) - self.data = {} - - def mergeValues(self, iterator): - """ Combine the items by creator and combiner """ - # speed up attributes lookup - d, creator = self.data, self.agg.createCombiner - comb = self.agg.mergeValue - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else creator(v) - - def mergeCombiners(self, iterator): - """ Merge the combined items by mergeCombiner """ - # speed up attributes lookup - d, comb = self.data, self.agg.mergeCombiners - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - - def items(self): - """ Return the merged items ad iterator """ - return iter(self.data.items()) - - def _compressed_serializer(self, serializer=None): # always use PickleSerializer to simplify implementation ser = PickleSerializer() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 647504c32f15..f11aaf001c8d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -62,7 +62,7 @@ CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ FlattenedValuesSerializer -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter +from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -95,17 +95,6 @@ def setUp(self): lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x) - def test_in_memory(self): - m = InMemoryMerger(self.agg) - m.mergeValues(self.data) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = InMemoryMerger(self.agg) - m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data)) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 9de75f4c4d08..b9fb90d96420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -330,11 +330,6 @@ private[spark] object SQLConf { // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. 
- val EXTERNAL_SORT = booleanConf("spark.sql.planner.externalSort", - defaultValue = Some(true), - doc = "When true, performs sorts spilling to disk as needed otherwise sort each partition in" + - " memory.") - val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", defaultValue = Some(true), doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") @@ -422,6 +417,7 @@ private[spark] object SQLConf { object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + val EXTERNAL_SORT = "spark.sql.planner.externalSort" } } @@ -476,8 +472,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) - private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) - private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5e40d7768904..41b215c79296 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -312,8 +312,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && TungstenSort.supportsSchema(child.schema)) { execution.TungstenSort(sortExprs, global, child) - } else if (sqlContext.conf.externalSortEnabled) { - execution.ExternalSort(sortExprs, global, child) } else { execution.Sort(sortExprs, global, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 95209e663451..af28e2dfa418 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -105,6 +105,15 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Deprecated.EXTERNAL_SORT, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.EXTERNAL_SORT} is deprecated and will be ignored. " + + s"External sort will continue to be used.") + Seq(Row(SQLConf.Deprecated.EXTERNAL_SORT, "true")) + } + (keyValueOutput, runFunc) + // Configures a single property. case Some((key, Some(value))) => val runFunc = (sqlContext: SQLContext) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 40ef7c3b5353..27f26245a5ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -31,38 +31,12 @@ import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} // This file defines various sort operators. //////////////////////////////////////////////////////////////////////////////////////////////////// - -/** - * Performs a sort on-heap. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. 
- */ -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - iterator.map(_.copy()).toArray.sorted(ordering).iterator - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - /** * Performs a sort, spilling to disk as needed. * @param global when true performs a global sort of all partitions by shuffling the data first * if necessary. */ -case class ExternalSort( +case class Sort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan) @@ -93,7 +67,7 @@ case class ExternalSort( } /** - * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of + * Optimized version of [[Sort]] that operates on binary data (implemented as part of * Project Tungsten). * * @param global when true performs a global sort of all partitions by shuffling the data first diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f9981356f364..05b4127cbcaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -581,28 +581,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } - test("sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false") { - sortTest() - } - } - test("external sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true") { - sortTest() - } - } - - test("SPARK-6927 sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "true") { - sortTest() - } + sortTest() } test("SPARK-6927 external sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { sortTest() } } @@ -1731,10 +1715,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("external sorting updates peak execution memory") { - withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { - AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { - sortTest() - } + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { + sortTest() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 4492e37ad01f..5dc37e5c3c23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -32,7 +32,7 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { case c: ConvertToSafe => c } - private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + private val outputsSafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(!outputsSafe.outputsUnsafeRows) private val outputsUnsafe = TungstenSort(Nil, 
false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 3073d492e613..847c188a3033 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -36,13 +36,13 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + Sort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + Sort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } From 1aa9e50256988533fa54584b49dbc408a14438ee Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 20 Sep 2015 16:05:12 -0700 Subject: [PATCH 041/896] [SPARK-5905] [MLLIB] Note requirements for certain RowMatrix methods in docs Note methods that fail for cols > 65535; note that SVD does not require n >= m CC mengxr Author: Sean Owen Closes #8839 from srowen/SPARK-5905. --- .../spark/mllib/linalg/distributed/RowMatrix.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index e55ef26858ad..7c7d900af3d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -109,7 +109,8 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. + * Computes the Gramian matrix `A^T A`. Note that this cannot be computed on matrices with + * more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { @@ -150,7 +151,8 @@ class RowMatrix @Since("1.0.0") ( * - s is a Vector of size k, holding the singular values in descending order, * - V is a Matrix of size n x k that satisfies V' * V = eye(k). * - * We assume n is smaller than m. The singular values and the right singular vectors are derived + * We assume n is smaller than m, though this is not strictly required. + * The singular values and the right singular vectors are derived * from the eigenvalues and the eigenvectors of the Gramian matrix A' * A. U, the matrix * storing the right singular vectors, is computed via matrix multiplication as * U = A * (V * S^-1^), if requested by user. The actual method to use is determined @@ -320,7 +322,8 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the covariance matrix, treating each row as an observation. + * Computes the covariance matrix, treating each row as an observation. Note that this cannot + * be computed on matrices with more than 65535 columns. * @return a local dense matrix of size n x n */ @Since("1.0.0") @@ -374,6 +377,8 @@ class RowMatrix @Since("1.0.0") ( * The row data do not need to be "centered" first; it is not necessary for * the mean of each column to be 0. * + * Note that this cannot be computed on matrices with more than 65535 columns. + * * @param k number of top principal components. 
* @return a matrix of size n-by-k, whose columns are principal components */ From 0c498717ba9622b6c889e701e8eed5ef9215c030 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Sun, 20 Sep 2015 16:16:31 -0700 Subject: [PATCH 042/896] [SPARK-10715] [ML] Duplicate initialization flag in WeightedLeastSquare There are duplicate set of initialization flag in `WeightedLeastSquares#add`. `initialized` is already set in `init(Int)`. Author: lewuathe Closes #8837 from Lewuathe/duplicate-initialization-flag. --- .../scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 0ff8931b0bab..4374e9963156 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -193,7 +193,6 @@ private[ml] object WeightedLeastSquares { val ak = a.size if (!initialized) { init(ak) - initialized = true } assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.") count += 1L From 01440395176bdbb2662480f03b27851cb860f385 Mon Sep 17 00:00:00 2001 From: vinodkc Date: Sun, 20 Sep 2015 22:55:24 -0700 Subject: [PATCH 043/896] [SPARK-10631] [DOCUMENTATION, MLLIB, PYSPARK] Added documentation for few APIs There are some missing API docs in pyspark.mllib.linalg.Vector (including DenseVector and SparseVector). We should add them based on their Scala counterparts. Author: vinodkc Closes #8834 from vinodkc/fix_SPARK-10631. --- python/pyspark/mllib/linalg/__init__.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 4829acb16ed8..f929e3e96fbe 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -301,11 +301,14 @@ def __reduce__(self): return DenseVector, (self.array.tostring(),) def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros + """ return np.count_nonzero(self.array) def norm(self, p): """ - Calculte the norm of a DenseVector. + Calculates the norm of a DenseVector. >>> a = DenseVector([0, -1, 2, -3]) >>> a.norm(2) @@ -397,10 +400,16 @@ def squared_distance(self, other): return np.dot(diff, diff) def toArray(self): + """ + Returns an numpy.ndarray + """ return self.array @property def values(self): + """ + Returns a list of values + """ return self.array def __getitem__(self, item): @@ -479,8 +488,8 @@ def __init__(self, size, *args): :param size: Size of the vector. :param args: Active entries, as a dictionary {index: value, ...}, - a list of tuples [(index, value), ...], or a list of strictly i - ncreasing indices and a list of corresponding values [index, ...], + a list of tuples [(index, value), ...], or a list of strictly + increasing indices and a list of corresponding values [index, ...], [value, ...]. Inactive entries are treated as zeros. >>> SparseVector(4, {1: 1.0, 3: 5.5}) @@ -521,11 +530,14 @@ def __init__(self, size, *args): raise TypeError("indices array must be sorted") def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros. + """ return np.count_nonzero(self.values) def norm(self, p): """ - Calculte the norm of a SparseVector. + Calculates the norm of a SparseVector. 
>>> a = SparseVector(4, [0, 1], [3., -4.]) >>> a.norm(1) @@ -797,7 +809,7 @@ def sparse(size, *args): values (sorted by index). :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tupes, + :param args: Non-zero entries, as a dictionary, list of tuples, or two sorted lists containing indices and values. >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) From 20a61dbd9b57957fcc5b58ef8935533914172b07 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 18:53:28 +0100 Subject: [PATCH 044/896] [SPARK-10626] [MLLIB] create java friendly method for random rdd SPARK-3136 added a large number of functions for creating Java RandomRDDs, but for people that want to use custom RandomDataGenerators we should make a Java friendly method. Author: Holden Karau Closes #8782 from holdenk/SPARK-10626-create-java-friendly-method-for-randomRDD. --- .../spark/mllib/random/RandomRDDs.scala | 52 ++++++++++++++++++- .../mllib/random/JavaRandomRDDsSuite.java | 30 +++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 4dd5ea214d67..f8ff26b5795b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} import org.apache.spark.rdd.RDD @@ -381,7 +382,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi @Since("1.1.0") @@ -394,6 +395,55 @@ object RandomRDDs { new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * :: DeveloperApi :: + * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator. + * + * @param jsc JavaSparkContext used to create the RDD. + * @param generator RandomDataGenerator used to populate the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int, + seed: Long): JavaRDD[T] = { + implicit val ctag: ClassTag[T] = fakeClassTag + val rdd = randomRDD(jsc.sc, generator, size, numPartitions, seed) + JavaRDD.fromRDD(rdd) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed. 
+ */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, numPartitions, Utils.random.nextLong()) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed & numPartitions + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, 0); + } + // TODO Generate RDD[Vector] from multivariate distributions. /** diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index 33d81b1e9592..fce5f6712f46 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.random; +import java.io.Serializable; import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; @@ -231,4 +232,33 @@ public void testGammaVectorRDD() { } } + @Test + public void testArbitrary() { + long size = 10; + long seed = 1L; + int numPartitions = 0; + StringGenerator gen = new StringGenerator(); + JavaRDD rdd1 = randomJavaRDD(sc, gen, size); + JavaRDD rdd2 = randomJavaRDD(sc, gen, size, numPartitions); + JavaRDD rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(size, rdd.count()); + Assert.assertEquals(2, rdd.first().length()); + } + } +} + +// This is just a test generator, it always returns a string of 42 +class StringGenerator implements RandomDataGenerator, Serializable { + @Override + public String nextValue() { + return "42"; + } + @Override + public StringGenerator copy() { + return new StringGenerator(); + } + @Override + public void setSeed(long seed) { + } } From ebbf85f07bb8de0d566f1ae4b41f26421180bebe Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 21 Sep 2015 11:39:04 -0700 Subject: [PATCH 045/896] [SPARK-7989] [SPARK-10651] [CORE] [TESTS] Increase timeout to fix flaky tests I noticed only one block manager registered with master in an unsuccessful build (https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.2,label=spark-test/3534/) ``` 15/09/16 13:02:30.981 pool-1-thread-1-ScalaTest-running-BroadcastSuite INFO SparkContext: Running Spark version 1.6.0-SNAPSHOT ... 15/09/16 13:02:38.133 sparkDriver-akka.actor.default-dispatcher-19 INFO BlockManagerMasterEndpoint: Registering block manager localhost:48196 with 530.3 MB RAM, BlockManagerId(0, localhost, 48196) ``` In addition, the first block manager needed 7+ seconds to start. But the test expected 2 block managers so it failed. However, there was no exception in this log file. So I checked a successful build (https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/3536/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.2,label=spark-test/) and it needed 4-5 seconds to set up the local cluster: ``` 15/09/16 18:11:27.738 sparkWorker1-akka.actor.default-dispatcher-5 INFO Worker: Running Spark version 1.6.0-SNAPSHOT ... 
15/09/16 18:11:30.838 sparkDriver-akka.actor.default-dispatcher-20 INFO BlockManagerMasterEndpoint: Registering block manager localhost:54202 with 530.3 MB RAM, BlockManagerId(1, localhost, 54202) 15/09/16 18:11:32.112 sparkDriver-akka.actor.default-dispatcher-20 INFO BlockManagerMasterEndpoint: Registering block manager localhost:32955 with 530.3 MB RAM, BlockManagerId(0, localhost, 32955) ``` In this build, the first block manager needed only 3+ seconds to start. Comparing these two builds, I guess it's possible that the local cluster in `BroadcastSuite` cannot be ready in 10 seconds if the Jenkins worker is busy. So I just increased the timeout to 60 seconds to see if this can fix the issue. Author: zsxwing Closes #8813 from zsxwing/fix-BroadcastSuite. --- .../scala/org/apache/spark/ExternalShuffleServiceSuite.scala | 2 +- .../test/scala/org/apache/spark/broadcast/BroadcastSuite.scala | 2 +- .../apache/spark/scheduler/SparkListenerWithClusterSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index e846a72c888c..231f4631e0a4 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -61,7 +61,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. // In this case, we won't receive FetchFailed. And it will make this test fail. // Therefore, we should wait until all slaves are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index fb7a8ae3f9d4..ba21075ce6be 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -311,7 +311,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up try { - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) + _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) _sc } catch { case e: Throwable => diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index d1e23ed527ff..9fa885938291 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -43,7 +43,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext // This test will check if the number of executors received by "SparkListener" is same as the // number of all executors, so we need to wait until all executors are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) From ca9fe540fe04e2e230d1e76526b5502bab152914 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski 
Date: Mon, 21 Sep 2015 19:46:39 +0100 Subject: [PATCH 046/896] [SPARK-10662] [DOCS] Code snippets are not properly formatted in tables * Backticks are processed properly in Spark Properties table * Removed unnecessary spaces * See http://people.apache.org/~pwendell/spark-nightly/spark-master-docs/latest/running-on-yarn.html Author: Jacek Laskowski Closes #8795 from jaceklaskowski/docs-yarn-formatting. --- docs/configuration.md | 97 +++++++++++++++-------------- docs/programming-guide.md | 100 +++++++++++++++--------------- docs/running-on-mesos.md | 14 ++--- docs/running-on-yarn.md | 106 ++++++++++++++++---------------- docs/sql-programming-guide.md | 16 ++--- docs/submitting-applications.md | 8 +-- 6 files changed, 171 insertions(+), 170 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 5ec097c78aa3..b22587c70316 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -34,20 +34,20 @@ val conf = new SparkConf() val sc = new SparkContext(conf) {% endhighlight %} -Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may +Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may actually require one to prevent any sort of starvation issues. -Properties that specify some time duration should be configured with a unit of time. +Properties that specify some time duration should be configured with a unit of time. The following format is accepted: - + 25ms (milliseconds) 5s (seconds) 10m or 10min (minutes) 3h (hours) 5d (days) 1y (years) - - + + Properties that specify a byte size should be configured with a unit of size. The following format is accepted: @@ -140,7 +140,7 @@ of the most common options to set are: Amount of memory to use for the driver process, i.e. where SparkContext is initialized. (e.g. 1g, 2g). - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-memory command line option @@ -207,7 +207,7 @@ Apart from these, the following properties are also available, and may be useful
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-class-path command line option or in + Instead, please set this through the --driver-class-path command line option or in your default properties file. @@ -216,10 +216,10 @@ Apart from these, the following properties are also available, and may be useful (none) A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-java-options command line option or in + Instead, please set this through the --driver-java-options command line option or in your default properties file. @@ -228,10 +228,10 @@ Apart from these, the following properties are also available, and may be useful (none) Set a special library path to use when launching the driver JVM. - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-library-path command line option or in + Instead, please set this through the --driver-library-path command line option or in your default properties file. @@ -242,7 +242,7 @@ Apart from these, the following properties are also available, and may be useful (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading classes in the the driver. This feature can be used to mitigate conflicts between Spark's dependencies and user dependencies. It is currently an experimental feature. - + This is used in cluster mode only. @@ -250,8 +250,8 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraClassPath (none) - Extra classpath entries to prepend to the classpath of executors. This exists primarily for - backwards-compatibility with older versions of Spark. Users typically should not need to set + Extra classpath entries to prepend to the classpath of executors. This exists primarily for + backwards-compatibility with older versions of Spark. Users typically should not need to set this option. @@ -259,9 +259,9 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraJavaOptions (none) - A string of extra JVM options to pass to executors. For instance, GC settings or other logging. - Note that it is illegal to set Spark properties or heap size settings with this option. Spark - properties should be set using a SparkConf object or the spark-defaults.conf file used with the + A string of extra JVM options to pass to executors. For instance, GC settings or other logging. + Note that it is illegal to set Spark properties or heap size settings with this option. Spark + properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Heap size settings can be set with spark.executor.memory. @@ -305,7 +305,7 @@ Apart from these, the following properties are also available, and may be useful daily Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or + Rolling is disabled by default. Valid values are daily, hourly, minutely or any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -330,13 +330,13 @@ Apart from these, the following properties are also available, and may be useful spark.python.profile false - Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, + Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), or it will be displayed before the driver exiting. It also can be dumped into disk by - `sc.dump_profiles(path)`. If some of the profile results had been displayed manually, + sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. - By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by - passing a profiler class in as a parameter to the `SparkContext` constructor. + By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by + passing a profiler class in as a parameter to the SparkContext constructor. 
@@ -460,11 +460,11 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.service.enabled false - Enables the external shuffle service. This service preserves the shuffle files written by - executors so the executors can be safely removed. This must be enabled if + Enables the external shuffle service. This service preserves the shuffle files written by + executors so the executors can be safely removed. This must be enabled if spark.dynamicAllocation.enabled is "true". The external shuffle service must be set up in order to enable it. See - dynamic allocation + dynamic allocation configuration and setup documentation for more information. @@ -747,9 +747,9 @@ Apart from these, the following properties are also available, and may be useful 1 in YARN mode, all the available cores on the worker in standalone mode. The number of cores to use on each executor. For YARN and standalone mode only. - - In standalone mode, setting this parameter allows an application to run multiple executors on - the same worker, provided that there are enough cores on that worker. Otherwise, only one + + In standalone mode, setting this parameter allows an application to run multiple executors on + the same worker, provided that there are enough cores on that worker. Otherwise, only one executor per application will run on each worker. @@ -893,14 +893,14 @@ Apart from these, the following properties are also available, and may be useful spark.akka.heartbeat.interval 1000s - This is set to a larger value to disable the transport failure detector that comes built in to - Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger - interval value reduces network overhead and a smaller value ( ~ 1 s) might be more - informative for Akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` - if you need to. A likely positive use case for using failure detector would be: a sensistive - failure detector can help evict rogue executors quickly. However this is usually not the case - as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling - this leads to a lot of exchanges of heart beats between nodes leading to flooding the network + This is set to a larger value to disable the transport failure detector that comes built in to + Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger + interval value reduces network overhead and a smaller value ( ~ 1 s) might be more + informative for Akka's failure detector. Tune this in combination of spark.akka.heartbeat.pauses + if you need to. A likely positive use case for using failure detector would be: a sensistive + failure detector can help evict rogue executors quickly. However this is usually not the case + as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling + this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those. @@ -909,9 +909,9 @@ Apart from these, the following properties are also available, and may be useful 6000s This is set to a larger value to disable the transport failure detector that comes built in to Akka. - It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart + It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause for Akka. This can be used to control sensitivity to GC pauses. 
Tune - this along with `spark.akka.heartbeat.interval` if you need to. + this along with spark.akka.heartbeat.interval if you need to. @@ -978,7 +978,7 @@ Apart from these, the following properties are also available, and may be useful spark.network.timeout 120s - Default timeout for all network interactions. This config will be used in place of + Default timeout for all network interactions. This config will be used in place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, spark.storage.blockManagerSlaveTimeoutMs, spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or @@ -991,8 +991,8 @@ Apart from these, the following properties are also available, and may be useful Maximum number of retries when binding to a port before giving up. When a port is given a specific value (non 0), each subsequent retry will - increment the port used in the previous attempt by 1 before retrying. This - essentially allows it to try a range of ports from the start port specified + increment the port used in the previous attempt by 1 before retrying. This + essentially allows it to try a range of ports from the start port specified to port + maxRetries. @@ -1191,7 +1191,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.executorIdleTimeout 60s - If dynamic allocation is enabled and an executor has been idle for more than this duration, + If dynamic allocation is enabled and an executor has been idle for more than this duration, the executor will be removed. For more detail, see this description. @@ -1424,11 +1424,11 @@ Apart from these, the following properties are also available, and may be useful false Enables or disables Spark Streaming's internal backpressure mechanism (since 1.5). - This enables the Spark Streaming to control the receiving rate based on the + This enables the Spark Streaming to control the receiving rate based on the current batch scheduling delays and processing times so that the system receives - only as fast as the system can process. Internally, this dynamically sets the + only as fast as the system can process. Internally, this dynamically sets the maximum receiving rate of receivers. This rate is upper bounded by the values - `spark.streaming.receiver.maxRate` and `spark.streaming.kafka.maxRatePerPartition` + spark.streaming.receiver.maxRate and spark.streaming.kafka.maxRatePerPartition if they are set (see below). @@ -1542,15 +1542,15 @@ The following variables can be set in `spark-env.sh`: Environment VariableMeaning JAVA_HOME - Location where Java is installed (if it's not on your default `PATH`). + Location where Java is installed (if it's not on your default PATH). PYSPARK_PYTHON - Python binary executable to use for PySpark in both driver and workers (default is `python`). + Python binary executable to use for PySpark in both driver and workers (default is python). PYSPARK_DRIVER_PYTHON - Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). + Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). SPARK_LOCAL_IP @@ -1580,4 +1580,3 @@ Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can config To specify a different configuration directory other than the default "SPARK_HOME/conf", you can set SPARK_CONF_DIR. Spark will use the the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) from this directory. 
- diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 4cf83bb39263..8ad238315f12 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -182,8 +182,8 @@ in-process. In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add JARs to the classpath -by passing a comma-separated list to the `--jars` argument. You can also add dependencies -(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +by passing a comma-separated list to the `--jars` argument. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly four cores, use: @@ -217,7 +217,7 @@ context connects to using the `--master` argument, and you can add Python .zip, to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies (e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) -can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in +can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in the requirements.txt of that package) must be manually installed using pip when necessary. For example, to run `bin/pyspark` on exactly four cores, use: @@ -249,8 +249,8 @@ the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support $ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook" ./bin/pyspark {% endhighlight %} -After the IPython Notebook server is launched, you can create a new "Python 2" notebook from -the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of +After the IPython Notebook server is launched, you can create a new "Python 2" notebook from +the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of your notebook before you start to try Spark from the IPython notebook. @@ -418,9 +418,9 @@ Apart from text files, Spark's Python API also supports several other data forma **Writable Support** -PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the -resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, -PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following +PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the +resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, +PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following Writables are automatically converted: @@ -435,9 +435,9 @@ Writables are automatically converted:
    MapWritable dict
    -Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, -users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default -converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get +Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, +users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default +converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get Python `array.array` for arrays of primitive types, users need to specify custom converters. **Saving and Loading SequenceFiles** @@ -454,7 +454,7 @@ classes can be specified, but for standard Writables this is not required. **Saving and Loading Other Hadoop Input/Output Formats** -PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. +PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. If required, a Hadoop configuration can be passed in as a Python dict. Here is an example using the Elasticsearch ESInputFormat: @@ -474,15 +474,15 @@ Note that, if the InputFormat simply depends on a Hadoop configuration and/or in the key and value classes can easily be converted according to the above table, then this approach should work well for such cases. -If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to +If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to transform that data on the Scala/Java side to something which can be handled by Pyrolite's pickler. -A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided -for this. Simply extend this trait and implement your transformation code in the ```convert``` -method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark +A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided +for this. Simply extend this trait and implement your transformation code in the ```convert``` +method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark classpath. -See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and -the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) +See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and +the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters. 
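To make the Converter approach described above more concrete, here is a minimal sketch of a custom converter written on the Scala side. The class name and the Text-to-String conversion are hypothetical and chosen only for illustration; the org.apache.spark.api.python.Converter trait itself is the Spark extension point mentioned above.

{% highlight scala %}
import org.apache.hadoop.io.Text
import org.apache.spark.api.python.Converter

// Hypothetical converter: unwraps a Hadoop Text value into a plain String so
// that Pyrolite can pickle it for PySpark. Package this class, along with any
// InputFormat dependencies, into the job jar and add it to the PySpark classpath.
class TextToStringConverter extends Converter[Any, String] {
  override def convert(obj: Any): String = obj match {
    case t: Text => t.toString
    case other   => String.valueOf(other)
  }
}
{% endhighlight %}

Such a converter is then referenced by its fully qualified class name through the keyConverter or valueConverter arguments of the PySpark Hadoop I/O methods, as in the Converter examples linked above.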
@@ -758,7 +758,7 @@ One of the harder things about Spark is understanding the scope and life cycle o #### Example -Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN): +Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN):
    @@ -777,7 +777,7 @@ println("Counter value: " + counter)
    {% highlight java %} int counter = 0; -JavaRDD rdd = sc.parallelize(data); +JavaRDD rdd = sc.parallelize(data); // Wrong: Don't do this!! rdd.foreach(x -> counter += x); @@ -803,7 +803,7 @@ print("Counter value: " + counter) #### Local vs. cluster modes -The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. +The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. However, in `cluster` mode, what happens is more complicated, and the above may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks - each of which is operated on by an executor. Prior to execution, Spark computes the **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. In `local` mode, there is only the one executors so everything shares the same closure. In other modes however, this is not the case and the executors running on seperate worker nodes each have their own copy of the closure. @@ -813,9 +813,9 @@ To ensure well-defined behavior in these sorts of scenarios one should use an [` In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. -#### Printing elements of an RDD +#### Printing elements of an RDD Another common idiom is attempting to print out the elements of an RDD using `rdd.foreach(println)` or `rdd.map(println)`. On a single machine, this will generate the expected output and print all the RDD's elements. However, in `cluster` mode, the output to `stdout` being called by the executors is now writing to the executor's `stdout` instead, not the one on the driver, so `stdout` on the driver won't show these! To print all elements on the driver, one can use the `collect()` method to first bring the RDD to the driver node thus: `rdd.collect().foreach(println)`. This can cause the driver to run out of memory, though, because `collect()` fetches the entire RDD to a single machine; if you only need to print a few elements of the RDD, a safer approach is to use the `take()`: `rdd.take(100).foreach(println)`. - + ### Working with Key-Value Pairs
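Before the guide moves on to key-value pairs, the closure discussion above can be tied together with a short accumulator-based rewrite of the counter example. This is a sketch in Scala; it assumes an existing SparkContext sc, and the data is made up.

{% highlight scala %}
// Safe replacement for the "counter += x" anti-pattern shown above: accumulator
// updates performed inside an action are merged back on the driver.
val data = Seq(1, 2, 3, 4, 5)
val rdd = sc.parallelize(data)

val sum = sc.accumulator(0)      // Spark 1.x accumulator API
rdd.foreach(x => sum += x)       // foreach is an action, so this runs on executors

println("Counter value: " + sum.value)
{% endhighlight %}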
    @@ -859,7 +859,7 @@ only available on RDDs of key-value pairs. The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. -In Java, key-value pairs are represented using the +In Java, key-value pairs are represented using the [scala.Tuple2](http://www.scala-lang.org/api/{{site.SCALA_VERSION}}/index.html#scala.Tuple2) class from the Scala standard library. You can simply call `new Tuple2(a, b)` to create a tuple, and access its fields later with `tuple._1()` and `tuple._2()`. @@ -974,7 +974,7 @@ for details. groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
    Note: If you are grouping in order to perform an aggregation (such as a sum or - average) over each key, using reduceByKey or aggregateByKey will yield much better + average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
    Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. @@ -1025,7 +1025,7 @@ for details. repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, - sort records by their keys. This is more efficient than calling repartition and then sorting within + sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -1038,7 +1038,7 @@ RDD API doc [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), [Python](api/python/pyspark.html#pyspark.RDD), [R](api/R/index.html)) - + and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1094,7 +1094,7 @@ for details. foreach(func) - Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. + Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems.
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. @@ -1118,13 +1118,13 @@ co-located to compute the result. In Spark, data is generally not distributed across partitions to be in the necessary place for a specific operation. During computations, a single task will operate on a single partition - thus, to organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an -all-to-all operation. It must read from all partitions to find all the values for all keys, -and then bring together values across partitions to compute the final result for each key - +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - this is called the **shuffle**. Although the set of elements in each partition of newly shuffled data will be deterministic, and so -is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably -ordered data following shuffle then it's possible to use: +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: * `mapPartitions` to sort each partition using, for example, `.sorted` * `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning @@ -1141,26 +1141,26 @@ network I/O. To organize data for the shuffle, Spark generates sets of tasks - * organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from MapReduce and does not directly relate to Spark's `map` and `reduce` operations. -Internally, results from individual map tasks are kept in memory until they can't fit. Then, these -are sorted based on the target partition and written to a single file. On the reduce side, tasks +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks read the relevant sorted blocks. - -Certain shuffle operations can consume significant amounts of heap memory since they employ -in-memory data structures to organize records before or after transferring them. Specifically, -`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations -generate these on the reduce side. When data does not fit in memory Spark will spill these tables + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables to disk, incurring the additional overhead of disk I/O and increased garbage collection. Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files -are preserved until the corresponding RDDs are no longer used and are garbage collected. -This is done so the shuffle files don't need to be re-created if the lineage is re-computed. -Garbage collection may happen only after a long period time, if the application retains references -to these RDDs or if GC does not kick in frequently. 
This means that long-running Spark jobs may +are preserved until the corresponding RDDs are no longer used and are garbage collected. +This is done so the shuffle files don't need to be re-created if the lineage is re-computed. +Garbage collection may happen only after a long period time, if the application retains references +to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may consume a large amount of disk space. The temporary storage directory is specified by the `spark.local.dir` configuration parameter when configuring the Spark context. Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the -'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). ## RDD Persistence @@ -1246,7 +1246,7 @@ efficiency. We recommend going through the following process to select one: This is the most CPU-efficient option, allowing operations on the RDDs to run as fast as possible. * If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to -make the objects much more space-efficient, but still reasonably fast to access. +make the objects much more space-efficient, but still reasonably fast to access. * Don't spill to disk unless the functions that computed your datasets are expensive, or they filter a large amount of the data. Otherwise, recomputing a partition may be as fast as reading it from @@ -1345,7 +1345,7 @@ Accumulators are variables that are only "added" to through an associative opera therefore be efficiently supported in parallel. They can be used to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers can add support for new types. If accumulators are created with a name, they will be -displayed in Spark's UI. This can be useful for understanding the progress of +displayed in Spark's UI. This can be useful for understanding the progress of running stages (NOTE: this is not yet supported in Python). An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks @@ -1474,8 +1474,8 @@ vecAccum = sc.accumulator(Vector(...), VectorAccumulatorParam())
    -For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator -will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware +For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator +will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware of that each task's update may be applied more than once if tasks or job stages are re-executed. Accumulators do not change the lazy evaluation model of Spark. If they are being updated within an operation on an RDD, their value is only updated once that RDD is computed as part of an action. Consequently, accumulator updates are not guaranteed to be executed when made within a lazy transformation like `map()`. The below code fragment demonstrates this property: @@ -1486,7 +1486,7 @@ Accumulators do not change the lazy evaluation model of Spark. If they are being {% highlight scala %} val accum = sc.accumulator(0) data.map { x => accum += x; f(x) } -// Here, accum is still 0 because no actions have caused the `map` to be computed. +// Here, accum is still 0 because no actions have caused the map to be computed. {% endhighlight %}
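Continuing the fragment above, a minimal sketch of how running an action materializes the pending update (it reuses the data, accum, and f definitions assumed by that example):

{% highlight scala %}
val mapped = data.map { x => accum += x; f(x) }
mapped.count()        // an action forces the map to be computed
println(accum.value)  // the accumulator now reflects the updates
{% endhighlight %}

As noted above, the update may be applied more than once if tasks or stages are re-executed, since it happens inside a transformation.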
    @@ -1553,7 +1553,7 @@ Several changes were made to the Java API: code that `extends Function` should `implement Function` instead. * New variants of the `map` transformations, like `mapToPair` and `mapToDouble`, were added to create RDDs of special data types. -* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning +* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning `(Key, List)` pairs to `(Key, Iterable)`.
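The earlier note on grouping, which recommends reduceByKey or aggregateByKey over groupByKey when the goal is an aggregation, can be illustrated with a small Scala sketch; the data here is invented:

{% highlight scala %}
val pairs = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 1)))

// reduceByKey combines values on the map side before the shuffle...
val viaReduce = pairs.reduceByKey(_ + _)

// ...whereas groupByKey ships every value across the network before summing.
val viaGroup = pairs.groupByKey().mapValues(_.sum)

viaReduce.collect().foreach(println)   // (a,2) and (b,1), in some order
{% endhighlight %}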
    diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 330c159c67bc..460a66f37dd6 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -245,7 +245,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.coarse false - If set to "true", runs over Mesos clusters in + If set to true, runs over Mesos clusters in "coarse-grained" sharing mode, where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use @@ -254,16 +254,16 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.extra.cores - 0 + 0 Set the extra amount of cpus to request per task. This setting is only used for Mesos coarse grain mode. The total amount of cores requested per task is the number of cores in the offer plus the extra cores configured. - Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. + Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. spark.mesos.mesosExecutor.cores - 1.0 + 1.0 (Fine-grained mode only) Number of cores to give each Mesos executor. This does not include the cores used to run the Spark tasks. In other words, even if no Spark task @@ -287,7 +287,7 @@ See the [configuration page](configuration.html) for information on Spark config Set the list of volumes which will be mounted into the Docker image, which was set using spark.mesos.executor.docker.image. The format of this property is a comma-separated list of - mappings following the form passed to docker run -v. That is they take the form: + mappings following the form passed to docker run -v. That is they take the form:
    [host_path:]container_path[:ro|:rw]
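As a hypothetical illustration of this volume syntax (the image name and paths below are invented), the two Docker properties could be set together like so:

{% highlight scala %}
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.mesos.executor.docker.image", "example/spark-executor:latest")  // hypothetical image
  .set("spark.mesos.executor.docker.volumes",
       "/srv/data:/srv/data:ro,/var/log/spark:/var/log/spark:rw")
{% endhighlight %}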
    @@ -318,7 +318,7 @@ See the [configuration page](configuration.html) for information on Spark config executor memory * 0.10, with minimum of 384 The amount of additional memory, specified in MB, to be allocated per executor. By default, - the overhead will be larger of either 384 or 10% of `spark.executor.memory`. If it's set, + the overhead will be larger of either 384 or 10% of spark.executor.memory. If set, the final overhead will be this value. @@ -339,7 +339,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.secret - (none)/td> + (none) Set the secret with which Spark framework will use to authenticate with Mesos. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 3a961d245f3d..0e25ccf512c0 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -23,7 +23,7 @@ Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.ht To launch a Spark application in `yarn-cluster` mode: $ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - + For example: $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ @@ -43,7 +43,7 @@ To launch a Spark application in `yarn-client` mode, do the same, but replace `y ## Adding Other JARs -In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. $ ./bin/spark-submit --class my.main.Class \ --master yarn-cluster \ @@ -64,16 +64,16 @@ Most of the configs are the same for Spark on YARN as for other deployment modes # Debugging your Application -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the `yarn logs` command. yarn logs -applicationId - + will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). The logs are also available on the Spark Web UI under the Executors Tab. You need to have both the Spark history server and the MapReduce history server running and configure `yarn.log.server.url` in `yarn-site.xml` properly. 
The log URL on the Spark history server UI will redirect you to the MapReduce history server to show the aggregated logs. When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. The logs are also available on the Spark Web UI under the Executors Tab and doesn't require running the MapReduce history server. To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +large value (e.g. `36000`), and then access the application cache through `yarn.nodemanager.local-dirs` on the nodes on which containers are launched. This directory contains the launch script, JARs, and all environment variables used for launching each container. This process is useful for debugging classpath problems in particular. (Note that enabling this requires admin privileges on cluster @@ -92,7 +92,7 @@ Note that for the first option, both executors and the application master will s log4j configuration, which may cause issues when they run on the same node (e.g. trying to write to the same log file). -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your `log4j.properties`. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming applications, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log files, and logs can be accessed using YARN's log utility. #### Spark Properties @@ -100,24 +100,26 @@ If you need a reference to the proper location to put log files in the YARN so t Property NameDefaultMeaning spark.yarn.am.memory - 512m + 512m Amount of memory to use for the YARN Application Master in client mode, in the same format as JVM memory strings (e.g. 512m, 2g). In cluster mode, use spark.driver.memory instead. +

    + Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. spark.driver.cores - 1 + 1 Number of cores used by the driver in YARN cluster mode. - Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN AM. - In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN AM instead. + Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN Application Master. + In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN Application Master instead. spark.yarn.am.cores - 1 + 1 Number of cores to use for the YARN Application Master in client mode. In cluster mode, use spark.driver.cores instead. @@ -125,39 +127,39 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.am.waitTime - 100s + 100s - In `yarn-cluster` mode, time for the application master to wait for the - SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait + In yarn-cluster mode, time for the YARN Application Master to wait for the + SparkContext to be initialized. In yarn-client mode, time for the YARN Application Master to wait for the driver to connect to it. spark.yarn.submit.file.replication - The default HDFS replication (usually 3) + The default HDFS replication (usually 3) HDFS replication level for the files uploaded into HDFS for the application. These include things like the Spark jar, the app jar, and any distributed cache files/archives. spark.yarn.preserve.staging.files - false + false - Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. + Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. spark.yarn.scheduler.heartbeat.interval-ms - 3000 + 3000 The interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. - The value is capped at half the value of YARN's configuration for the expiry interval - (yarn.am.liveness-monitor.expiry-interval-ms). + The value is capped at half the value of YARN's configuration for the expiry interval, i.e. + yarn.am.liveness-monitor.expiry-interval-ms. spark.yarn.scheduler.initial-allocation.interval - 200ms + 200ms The initial interval in which the Spark application master eagerly heartbeats to the YARN ResourceManager when there are pending container allocation requests. It should be no larger than @@ -177,8 +179,8 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.historyServer.address (none) - The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. - For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For eg, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to `${hadoopconf-yarn.resourcemanager.hostname}:18080`. + The address of the Spark history server, e.g. host.com:18080. 
The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For example, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to ${hadoopconf-yarn.resourcemanager.hostname}:18080. @@ -197,42 +199,42 @@ If you need a reference to the proper location to put log files in the YARN so t spark.executor.instances - 2 + 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). spark.yarn.driver.memoryOverhead driverMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). spark.yarn.am.memoryOverhead AM memory * 0.10, with minimum of 384 - Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode. + Same as spark.yarn.driver.memoryOverhead, but for the YARN Application Master in client mode. spark.yarn.am.port (random) - Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. + Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the YARN Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. spark.yarn.queue - default + default The name of the YARN queue to which the application is submitted. 
@@ -245,18 +247,18 @@ If you need a reference to the proper location to put log files in the YARN so t By default, Spark on YARN will use a Spark jar installed locally, but the Spark jar can also be in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't need to be distributed each time an application runs. To point to a jar on HDFS, for example, - set this configuration to "hdfs:///some/path". + set this configuration to hdfs:///some/path. spark.yarn.access.namenodes (none) - A list of secure HDFS namenodes your Spark application is going to access. For - example, `spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032`. - The Spark application must have acess to the namenodes listed and Kerberos must - be properly configured to be able to access them (either in the same realm or in - a trusted realm). Spark acquires security tokens for each of the namenodes so that + A comma-separated list of secure HDFS namenodes your Spark application is going to access. For + example, spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032. + The Spark application must have access to the namenodes listed and Kerberos must + be properly configured to be able to access them (either in the same realm or in + a trusted realm). Spark acquires security tokens for each of the namenodes so that the Spark application can access those remote HDFS clusters. @@ -264,18 +266,18 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.appMasterEnv.[EnvironmentVariableName] (none) - Add the environment variable specified by EnvironmentVariableName to the - Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In `yarn-cluster` mode this controls - the environment of the SPARK driver and in `yarn-client` mode it only controls - the environment of the executor launcher. + Add the environment variable specified by EnvironmentVariableName to the + Application Master process launched on YARN. The user can specify multiple of + these and to set multiple environment variables. In yarn-cluster mode this controls + the environment of the Spark driver and in yarn-client mode it only controls + the environment of the executor launcher. spark.yarn.containerLauncherMaxThreads - 25 + 25 - The maximum number of threads to use in the application master for launching executor containers. + The maximum number of threads to use in the YARN Application Master for launching executor containers. @@ -283,19 +285,19 @@ If you need a reference to the proper location to put log files in the YARN so t (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use `spark.driver.extraJavaOptions` instead. + In cluster mode, use spark.driver.extraJavaOptions instead. spark.yarn.am.extraLibraryPath (none) - Set a special library path to use when launching the application master in client mode. + Set a special library path to use when launching the YARN Application Master in client mode. spark.yarn.maxAppAttempts - yarn.resourcemanager.am.max-attempts in YARN + yarn.resourcemanager.am.max-attempts in YARN The maximum number of attempts that will be made to submit the application. It should be no larger than the global number of max attempts in the YARN configuration. 
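As with the core properties earlier, these YARN-specific settings can be supplied through spark-defaults.conf, --conf, or SparkConf. A brief sketch with invented values (the queue name and the environment variable value are hypothetical):

{% highlight scala %}
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.yarn.queue", "analytics")                         // hypothetical queue name
  .set("spark.yarn.appMasterEnv.PYSPARK_PYTHON", "python2.7")   // example variable for the Application Master
{% endhighlight %}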
@@ -303,10 +305,10 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.submit.waitAppCompletion - true + true In YARN cluster mode, controls whether the client waits to exit until the application completes. - If set to true, the client process will stay alive reporting the application's status. + If set to true, the client process will stay alive reporting the application's status. Otherwise, the client process will exit after submission. @@ -332,7 +334,7 @@ If you need a reference to the proper location to put log files in the YARN so t (none) The full path to the file that contains the keytab for the principal specified above. - This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, for renewing the login tickets and the delegation tokens periodically. @@ -371,14 +373,14 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.security.tokens.${service}.enabled - true + true Controls whether to retrieve delegation tokens for non-HDFS services when security is enabled. By default, delegation tokens for all supported services are retrieved when those services are configured, but it's possible to disable that behavior if it somehow conflicts with the application being run.

    - Currently supported services are: hive, hbase + Currently supported services are: hive, hbase @@ -387,5 +389,5 @@ If you need a reference to the proper location to put log files in the YARN so t - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. -- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. +- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7ae9244c271e..a1cbc7de97c6 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1676,7 +1676,7 @@ results <- collect(sql(sqlContext, "FROM src SELECT key, value")) ### Interacting with Different Versions of Hive Metastore One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, -which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. Note that independent of the version of Hive that is being used to talk to the metastore, internally Spark SQL will compile against Hive 1.2.1 and use those classes for internal execution (serdes, UDFs, UDAFs, etc). @@ -1706,8 +1706,8 @@ The following options can be used to configure the version of Hive that is used either 1.2.1 or not defined.

  • maven
  • Use Hive jars of specified version downloaded from Maven repositories. This configuration - is not generally recommended for production deployments. -
  • A classpath in the standard format for the JVM. This classpath must include all of Hive + is not generally recommended for production deployments. +
  • A classpath in the standard format for the JVM. This classpath must include all of Hive and its dependencies, including the correct version of Hadoop. These jars only need to be present on the driver, but if you are running in yarn cluster mode then you must ensure they are packaged with your application.
  • @@ -1806,7 +1806,7 @@ the Data Sources API. The following options are supported:
    {% highlight scala %} -val jdbcDF = sqlContext.read.format("jdbc").options( +val jdbcDF = sqlContext.read.format("jdbc").options( Map("url" -> "jdbc:postgresql:dbserver", "dbtable" -> "schema.tablename")).load() {% endhighlight %} @@ -2023,11 +2023,11 @@ options. - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with code generation for expression evaluation. These features can both be disabled by setting - `spark.sql.tungsten.enabled` to `false. - - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + `spark.sql.tungsten.enabled` to `false`. + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting `spark.sql.parquet.mergeSchema` to `true`. - - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or - access nested values. For example `df['table.column.nestedField']`. However, this means that if + - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or + access nested values. For example `df['table.column.nestedField']`. However, this means that if your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). - In-memory columnar storage partition pruning is on by default. It can be disabled by setting `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 7ea4d6f1a3f8..915be0f47915 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -103,7 +103,7 @@ run it with `--help`. Here are a few examples of common options: export HADOOP_CONF_DIR=XXX ./bin/spark-submit \ --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ # can also be `yarn-client` for client mode + --master yarn-cluster \ # can also be yarn-client for client mode --executor-memory 20G \ --num-executors 50 \ /path/to/examples.jar \ @@ -174,9 +174,9 @@ This can use up a significant amount of space over time and will need to be clea is handled automatically, and with Spark standalone, automatic cleanup can be configured with the `spark.worker.cleanup.appDataTtl` property. -Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates -with `--packages`. All transitive dependencies will be handled when using this command. Additional -repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. +Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates +with `--packages`. All transitive dependencies will be handled when using this command. Additional +repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. These commands can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages. For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries From 331f0b10f78a37d96d3e573d211d74a0935265db Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Mon, 21 Sep 2015 12:09:00 -0700 Subject: [PATCH 047/896] [SPARK-9642] [ML] LinearRegression should supported weighted data In many modeling application, data points are not necessarily sampled with equal probabilities. Linear regression should support weighting which account the over or under sampling. work in progress. 
Author: Meihua Wu Closes #8631 from rotationsymmetry/SPARK-9642. --- .../ml/regression/LinearRegression.scala | 164 +++++++++++------- .../ml/regression/LinearRegressionSuite.scala | 88 ++++++++++ project/MimaExcludes.scala | 8 +- 3 files changed, 191 insertions(+), 69 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index e4602d36ccc8..78a67c5fdab2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -31,21 +31,29 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.functions.{col, udf, lit} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.StatCounter /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol - with HasFitIntercept with HasStandardization + with HasFitIntercept with HasStandardization with HasWeightCol + +/** + * Class that represents an instance of weighted data point with label and features. + * + * TODO: Refactor this class to proper place. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param features The vector of features for this data point. + */ +private[regression] case class Instance(label: Double, weight: Double, features: Vector) /** * :: Experimental :: @@ -123,30 +131,43 @@ class LinearRegression(override val uid: String) def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) + /** + * Whether to over-/under-sample training instances according to the given weights in weightCol. + * If empty, all instances are treated equally (weight 1.0). + * Default is empty, so all instances have weight one. + * @group setParam + */ + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + override protected def train(dataset: DataFrame): LinearRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist instances. 
- val instances = extractLabeledPoints(dataset).map { - case LabeledPoint(label: Double, features: Vector) => (label, features) + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val (summarizer, statCounter) = instances.treeAggregate( - (new MultivariateOnlineSummarizer, new StatCounter))( - seqOp = (c, v) => (c, v) match { - case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), - (label: Double, features: Vector)) => - (summarizer.add(features), statCounter.merge(label)) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), - (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => - (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) - }) - - val numFeatures = summarizer.mean.size - val yMean = statCounter.mean - val yStd = math.sqrt(statCounter.variance) + val (featuresSummarizer, ySummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), + c._2.add(Vectors.dense(instance.label), instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.treeAggregate( + new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp) + } + + val numFeatures = featuresSummarizer.mean.size + val yMean = ySummarizer.mean(0) + val yStd = math.sqrt(ySummarizer.variance(0)) // If the yStd is zero, then the intercept is yMean with zero weights; // as a result, training is not needed. @@ -167,8 +188,8 @@ class LinearRegression(override val uid: String) return copyValues(model.setSummary(trainingSummary)) } - val featuresMean = summarizer.mean.toArray - val featuresStd = summarizer.variance.toArray.map(math.sqrt) + val featuresMean = featuresSummarizer.mean.toArray + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) // Since we implicitly do the feature scaling when we compute the cost function // to improve the convergence, the effective regParam will be changed. @@ -318,7 +339,8 @@ class LinearRegressionModel private[ml] ( /** * :: Experimental :: - * Linear regression training results. + * Linear regression training results. Currently, the training summary ignores the + * training weights except for the objective trace. * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @@ -477,7 +499,7 @@ class LinearRegressionSummary private[regression] ( * \frac{\partial L}{\partial\w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) * }}}, * - * @param weights The weights/coefficients corresponding to the features. + * @param coefficients The coefficients corresponding to the features. * @param labelStd The standard deviation value of the label. * @param labelMean The mean value of the label. * @param fitIntercept Whether to fit an intercept term. 
@@ -485,7 +507,7 @@ class LinearRegressionSummary private[regression] ( * @param featuresMean The mean values of the features. */ private class LeastSquaresAggregator( - weights: Vector, + coefficients: Vector, labelStd: Double, labelMean: Double, fitIntercept: Boolean, @@ -493,26 +515,28 @@ private class LeastSquaresAggregator( featuresMean: Array[Double]) extends Serializable { private var totalCnt: Long = 0L + private var weightSum: Double = 0.0 private var lossSum = 0.0 - private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { - val weightsArray = weights.toArray.clone() + private val (effectiveCoefficientsArray: Array[Double], offset: Double, dim: Int) = { + val coefficientsArray = coefficients.toArray.clone() var sum = 0.0 var i = 0 - val len = weightsArray.length + val len = coefficientsArray.length while (i < len) { if (featuresStd(i) != 0.0) { - weightsArray(i) /= featuresStd(i) - sum += weightsArray(i) * featuresMean(i) + coefficientsArray(i) /= featuresStd(i) + sum += coefficientsArray(i) * featuresMean(i) } else { - weightsArray(i) = 0.0 + coefficientsArray(i) = 0.0 } i += 1 } - (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length) + val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 + (coefficientsArray, offset, coefficientsArray.length) } - private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + private val effectiveCoefficientsVector = Vectors.dense(effectiveCoefficientsArray) private val gradientSumArray = Array.ofDim[Double](dim) @@ -520,30 +544,33 @@ private class LeastSquaresAggregator( * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient * of the objective function. * - * @param label The label for this data point. - * @param data The features for one data point in dense/sparse vector format to be added - * into this aggregator. + * @param instance The data point instance to be added. * @return This LeastSquaresAggregator object. */ - def add(label: Double, data: Vector): this.type = { - require(dim == data.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${data.size}.") + def add(instance: Instance): this.type = + instance match { case Instance(label, weight, features) => + require(dim == features.size, s"Dimensions mismatch when adding new sample." 
+ + s" Expecting $dim but got ${features.size}.") + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") - val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + if (weight == 0.0) return this - if (diff != 0) { - val localGradientSumArray = gradientSumArray - data.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - localGradientSumArray(index) += diff * value / featuresStd(index) + val diff = dot(features, effectiveCoefficientsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + features.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += weight * diff * value / featuresStd(index) + } } + lossSum += weight * diff * diff / 2.0 } - lossSum += diff * diff / 2.0 - } - totalCnt += 1 - this - } + totalCnt += 1 + weightSum += weight + this + } /** * Merge another LeastSquaresAggregator, and update the loss and gradient @@ -557,8 +584,9 @@ private class LeastSquaresAggregator( require(dim == other.dim, s"Dimensions mismatch when merging with another " + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") - if (other.totalCnt != 0) { + if (other.weightSum != 0) { totalCnt += other.totalCnt + weightSum += other.weightSum lossSum += other.lossSum var i = 0 @@ -574,11 +602,17 @@ private class LeastSquaresAggregator( def count: Long = totalCnt - def loss: Double = lossSum / totalCnt + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") + lossSum / weightSum + } def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / totalCnt, result) + scal(1.0 / weightSum, result) result } } @@ -589,7 +623,7 @@ private class LeastSquaresAggregator( * It's used in Breeze's convex optimization routines. 
*/ private class LeastSquaresCostFun( - data: RDD[(Double, Vector)], + data: RDD[Instance], labelStd: Double, labelMean: Double, fitIntercept: Boolean, @@ -598,17 +632,13 @@ private class LeastSquaresCostFun( featuresMean: Array[Double], effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { - override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { - val w = Vectors.fromBreeze(weights) + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + val coeff = Vectors.fromBreeze(coefficients) - val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(coeff, labelStd, labelMean, fitIntercept, featuresStd, featuresMean))( - seqOp = (c, v) => (c, v) match { - case (aggregator, (label, features)) => aggregator.add(label, features) - }, - combOp = (c1, c2) => (c1, c2) match { - case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + seqOp = (aggregator, instance) => aggregator.add(instance), + combOp = (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) val totalGradientArray = leastSquaresAggregator.gradient.toArray @@ -616,7 +646,7 @@ private class LeastSquaresCostFun( 0.0 } else { var sum = 0.0 - w.foreachActive { (index, value) => + coeff.foreachActive { (index, value) => // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. sum += { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2aaee71ecc73..8428f4f00b37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.ml.regression +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -510,4 +513,89 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .zip(testSummary.residuals.select("residuals").collect()) .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } } + + test("linear regression with weighted samples"){ + val (data, weightedData) = { + val activeData = LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + + val rnd = new Random(8392) + val signedData = activeData.map { case p: LabeledPoint => + (rnd.nextGaussian() > 0.0, p) + } + + val data1 = signedData.flatMap { + case (true, p) => Iterator(p, p) + case (false, p) => Iterator(p) + } + + val weightedSignedData = signedData.flatMap { + case (true, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 1.2, features), + Instance(label, weight = 0.8, features) + ) + case (false, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 0.3, features), + Instance(label, weight = 0.1, features), + Instance(label, weight = 0.6, features) + ) + } + + val noiseData = 
LinearDataGenerator.generateLinearInput( + 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + val weightedNoiseData = noiseData.map { + case LabeledPoint(label, features) => Instance(label, weight = 0, features) + } + val data2 = weightedSignedData ++ weightedNoiseData + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model1a0 = trainer1a.fit(data) + val model1a1 = trainer1a.fit(weightedData) + val model1b = trainer1b.fit(weightedData) + assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + val trainer2a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model2a0 = trainer2a.fit(data) + val model2a1 = trainer2a.fit(weightedData) + val model2b = trainer2b.fit(weightedData) + assert(model2a0.weights !~= model2a1.weights absTol 1E-3) + assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3) + assert(model2a0.weights ~== model2b.weights absTol 1E-3) + assert(model2a0.intercept ~== model2b.intercept absTol 1E-3) + + val trainer3a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model3a0 = trainer3a.fit(data) + val model3a1 = trainer3a.fit(weightedData) + val model3b = trainer3b.fit(weightedData) + assert(model3a0.weights !~= model3a1.weights absTol 1E-3) + assert(model3a0.weights ~== model3b.weights absTol 1E-3) + + val trainer4a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model4a0 = trainer4a.fit(data) + val model4a1 = trainer4a.fit(weightedData) + val model4b = trainer4b.fit(weightedData) + assert(model4a0.weights !~= model4a1.weights absTol 1E-3) + assert(model4a0.weights ~== model4b.weights absTol 1E-3) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 814a11e588ce..b2e6be706637 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,10 +70,14 @@ object MimaExcludes { "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") - ) ++ - Seq( + ) ++ Seq( ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), + ProblemFilters.exclude[MissingMethodProblem]( + 
"org.apache.spark.ml.regression.LeastSquaresCostFun.this") ) case v if v.startsWith("1.5") => Seq( From b78c65b03ae87a3ba348c9d29ff4c296349eb49c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?hushan=5B=E8=83=A1=E7=8F=8A=5D?= Date: Mon, 21 Sep 2015 14:26:15 -0500 Subject: [PATCH 048/896] [SPARK-5259] [CORE] don't submit stage until its dependencies map outputs are registered MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Track pending tasks by partition ID instead of Task objects. Before this change, failure & retry could result in a case where a stage got submitted before the map output from its dependencies get registered. This was due to an error in the condition for registering map outputs. Author: hushan[胡珊] Author: Imran Rashid Closes #7699 from squito/SPARK-5259. --- .../apache/spark/scheduler/DAGScheduler.scala | 12 +- .../org/apache/spark/scheduler/Stage.scala | 2 +- .../spark/scheduler/TaskSetManager.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 197 ++++++++++++++++-- 4 files changed, 191 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3c9a66e50440..394228b2728d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -944,7 +944,7 @@ class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry - stage.pendingTasks.clear() + stage.pendingPartitions.clear() // First figure out the indexes of partition ids to compute. val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = { @@ -1060,8 +1060,8 @@ class DAGScheduler( if (tasks.size > 0) { logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") - stage.pendingTasks ++= tasks - logDebug("New pending tasks: " + stage.pendingTasks) + stage.pendingPartitions ++= tasks.map(_.partitionId) + logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) @@ -1152,7 +1152,7 @@ class DAGScheduler( case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) - stage.pendingTasks -= task + stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => // Cast to ResultStage here because it's part of the ResultTask @@ -1198,7 +1198,7 @@ class DAGScheduler( shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) @@ -1242,7 +1242,7 @@ class DAGScheduler( case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") - stage.pendingTasks += task + stage.pendingPartitions += task.partitionId case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala 
b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index b37eccbd0f7b..a3829c319c48 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -66,7 +66,7 @@ private[scheduler] abstract class Stage( /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - var pendingTasks = new HashSet[Task[_]] + val pendingPartitions = new HashSet[Int] /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 62af9031b9f8..c02597c4365c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -487,8 +487,8 @@ private[spark] class TaskSetManager( // a good proxy to task serialization time. // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" - logInfo("Starting %s (TID %d, %s, %s, %d bytes)".format( - taskName, taskId, host, taskLocality, serializedTask.limit)) + logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," + + s"$taskLocality, ${serializedTask.limit} bytes)") sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 1c55f90ad9b4..6b5bcf0574de 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -479,8 +479,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), @@ -490,7 +490,7 @@ class DAGSchedulerSuite // ask the scheduler to try it again scheduler.resubmitFailedStages() // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) // we can see both result blocks now assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -782,8 +782,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -1035,6 +1035,173 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + /** + * This test runs a three stage job, with a fetch failure in stage 1. 
but during the retry, we + * have completions from both the first & second attempt of stage 1. So all the map output is + * available before we finish any task set for stage 1. We want to make sure that we don't + * submit stage 2 until the map output for stage 1 is registered + */ + test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleId = firstShuffleDep.shuffleId + val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + + // things start out smoothly, stage 0 completes with no issues + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostA", shuffleMapRdd.partitions.length)) + )) + + // then one executor dies, and a task fails in stage 1 + runEvent(ExecutorLost("exec-hostA")) + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), + null, + null, + createFakeTaskInfo(), + null)) + + // so we resubmit stage 0, which completes happily + scheduler.resubmitFailedStages() + val stage0Resubmit = taskSets(2) + assert(stage0Resubmit.stageId == 0) + assert(stage0Resubmit.stageAttemptId === 1) + val task = stage0Resubmit.tasks(0) + assert(task.partitionId === 2) + runEvent(CompletionEvent( + task, + Success, + makeMapStatus("hostC", shuffleMapRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // now here is where things get tricky : we will now have a task set representing + // the second attempt for stage 1, but we *also* have some tasks for the first attempt for + // stage 1 still going + val stage1Resubmit = taskSets(3) + assert(stage1Resubmit.stageId == 1) + assert(stage1Resubmit.stageAttemptId === 1) + assert(stage1Resubmit.tasks.length === 3) + + // we'll have some tasks finish from the first attempt, and some finish from the second attempt, + // so that we actually have all stage outputs, though no attempt has completed all its + // tasks + runEvent(CompletionEvent( + taskSets(3).tasks(0), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + runEvent(CompletionEvent( + taskSets(3).tasks(1), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + // late task finish from the first attempt + runEvent(CompletionEvent( + taskSets(1).tasks(2), + Success, + makeMapStatus("hostB", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // What should happen now is that we submit stage 2. However, we might not see an error + // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them). But + // we can check some conditions. + // Note that the really important thing here is not so much that we submit stage 2 *immediately* + // but that we don't end up with some error from these interleaved completions. 
It would also + // be OK (though sub-optimal) if stage 2 simply waited until the resubmission of stage 1 had + // all its tasks complete + + // check that we have all the map output for stage 0 (it should have been there even before + // the last round of completions from stage 1, but just to double check it hasn't been messed + // up) and also the newly available stage 1 + val stageToReduceIdxs = Seq( + 0 -> (0 until 3), + 1 -> (0 until 1) + ) + for { + (stage, reduceIdxs) <- stageToReduceIdxs + reduceIdx <- reduceIdxs + } { + // this would throw an exception if the map status hadn't been registered + val statuses = mapOutputTracker.getMapSizesByExecutorId(stage, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 2 has been submitted + assert(taskSets.size == 5) + val stage2TaskSet = taskSets(4) + assert(stage2TaskSet.stageId == 2) + assert(stage2TaskSet.stageAttemptId == 0) + } + + /** + * We lose an executor after completing some shuffle map tasks on it. Those tasks get + * resubmitted, and when they finish the job completes normally + */ + test("register map outputs correctly after ExecutorLost and task Resubmitted") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep)) + submit(reduceRdd, Array(0)) + + // complete some of the tasks from the first stage, on one host + runEvent(CompletionEvent( + taskSets(0).tasks(0), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + + // now that host goes down + runEvent(ExecutorLost("exec-hostA")) + + // so we resubmit those tasks + runEvent(CompletionEvent( + taskSets(0).tasks(0), Resubmitted, null, null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Resubmitted, null, null, createFakeTaskInfo(), null)) + + // now complete everything on a different host + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)) + )) + + // now we should submit stage 1, and the map output from stage 0 should be registered + + // check that we have all the map output for stage 0 + (0 until reduceRdd.partitions.length).foreach { reduceIdx => + val statuses = mapOutputTracker.getMapSizesByExecutorId(0, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 1 has been submitted + assert(taskSets.size == 2) + val stage1TaskSet = taskSets(1) + assert(stage1TaskSet.stageId == 1) + assert(stage1TaskSet.stageAttemptId == 0) + } + /** * Makes sure that failures of stage used by multiple jobs are correctly handled. 
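The scenario exercised above hinges on the data-structure change in Stage.scala: pending work is now tracked as a set of partition ids (pendingPartitions) rather than a set of Task objects, so a completion from any attempt of the stage clears the same entry. A small sketch of that idea, where every name is made up and only the keying-by-partitionId comes from the patch:

{{{
import scala.collection.mutable

object PendingPartitionsSketch {
  // Stand-in for a scheduler task; only stageAttemptId and partitionId matter here.
  final case class FakeTask(stageAttemptId: Int, partitionId: Int)

  def main(args: Array[String]): Unit = {
    val pendingPartitions = mutable.HashSet(0, 1, 2)

    // A late completion from the first attempt and a completion from the retry both
    // compute partition 2; tracked by partition id, either of them clears the entry.
    val lateFromAttempt0 = FakeTask(stageAttemptId = 0, partitionId = 2)
    val fromAttempt1 = FakeTask(stageAttemptId = 1, partitionId = 2)

    pendingPartitions -= lateFromAttempt0.partitionId
    assert(!pendingPartitions.contains(fromAttempt1.partitionId))

    // In this sketch, a set keyed on the FakeTask values themselves would still hold
    // the attempt-1 task at this point, because the two objects are distinct elements.
    println(s"still pending: $pendingPartitions")
  }
}
}}}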
* @@ -1393,8 +1560,8 @@ class DAGSchedulerSuite // Submit a map stage by itself submitMapStage(shuffleDep) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) assert(results.size === 1) results.clear() assertDataStructuresEmpty() @@ -1407,7 +1574,7 @@ class DAGSchedulerSuite // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch // from, then TaskSet 3 will run the reduce stage scheduler.resubmitFailedStages() - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) results.clear() @@ -1452,8 +1619,8 @@ class DAGSchedulerSuite // Complete the first stage assert(taskSets(0).stageId === 0) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", rdd1.partitions.size)), - (Success, makeMapStatus("hostB", rdd1.partitions.size)))) + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) assert(listener1.results.size === 1) @@ -1461,7 +1628,7 @@ class DAGSchedulerSuite // When attempting the second stage, show a fetch failure assert(taskSets(1).stageId === 1) complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", rdd2.partitions.size)), + (Success, makeMapStatus("hostA", rdd2.partitions.length)), (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() assert(listener2.results.size === 0) // Second stage listener should not have a result yet @@ -1469,7 +1636,7 @@ class DAGSchedulerSuite // Stage 0 should now be running as task set 2; make its task succeed assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( - (Success, makeMapStatus("hostC", rdd2.partitions.size)))) + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) assert(listener2.results.size === 0) // Second stage listener should still not have a result @@ -1477,8 +1644,8 @@ class DAGSchedulerSuite // Stage 1 should now be running as task set 3; make its first task succeed assert(taskSets(3).stageId === 1) complete(taskSets(3), Seq( - (Success, makeMapStatus("hostB", rdd2.partitions.size)), - (Success, makeMapStatus("hostD", rdd2.partitions.size)))) + (Success, makeMapStatus("hostB", rdd2.partitions.length)), + (Success, makeMapStatus("hostD", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) assert(listener2.results.size === 1) @@ -1494,7 +1661,7 @@ class DAGSchedulerSuite // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 assert(taskSets(5).stageId === 1) complete(taskSets(5), Seq( - (Success, makeMapStatus("hostE", rdd2.partitions.size)))) + (Success, makeMapStatus("hostE", rdd2.partitions.length)))) complete(taskSets(6), Seq( (Success, 53))) 
assert(listener3.results === Map(0 -> 52, 1 -> 53)) From ba882db6f43dd2bc05675133158e4664ed07030a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 13:06:23 -0700 Subject: [PATCH 049/896] [SPARK-9769] [ML] [PY] add python api for countvectorizermodel From JIRA: Add Python API, user guide and example for ml.feature.CountVectorizerModel Author: Holden Karau Closes #8561 from holdenk/SPARK-9769-add-python-api-for-countvectorizermodel. --- python/pyspark/ml/feature.py | 148 +++++++++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 6 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 92db8df80280..f41d72f87725 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -26,12 +26,13 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', - 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', 'RegexTokenizer', - 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', - 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', - 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] +__all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', + 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', 'MinMaxScaler', + 'MinMaxScalerModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', + 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', + 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', + 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', + 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -171,6 +172,141 @@ def getSplits(self): return self.getOrDefault(self.splits) +@inherit_doc +class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. + >>> df = sqlContext.createDataFrame( + ... [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], + ... ["label", "raw"]) + >>> cv = CountVectorizer(inputCol="raw", outputCol="vectors") + >>> model = cv.fit(df) + >>> model.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... + >>> sorted(map(str, model.vocabulary)) + ['a', 'b', 'c'] + """ + + # a placeholder to make it appear in the generated doc + minTF = Param( + Params._dummy(), "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then this " + + "specifies a fraction (out of the document's token count). Note that the parameter is " + + "only used in transform of CountVectorizerModel and does not affect fitting. 
Default 1.0") + minDF = Param( + Params._dummy(), "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0") + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.") + + @keyword_only + def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + """ + super(CountVectorizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", + self.uid) + self.minTF = Param( + self, "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then " + + "this specifies a fraction (out of the document's token count). Note that the " + + "parameter is only used in transform of CountVectorizerModel and does not affect" + + "fitting. Default 1.0") + self.minDF = Param( + self, "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of " + + "documents. Default 1.0") + self.vocabSize = Param( + self, "vocabSize", "max size of the vocabulary. Default 1 << 18.") + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + Set the params for the CountVectorizer + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + self._paramMap[self.minTF] = value + return self + + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + def setMinDF(self, value): + """ + Sets the value of :py:attr:`minDF`. + """ + self._paramMap[self.minDF] = value + return self + + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + def setVocabSize(self, value): + """ + Sets the value of :py:attr:`vocabSize`. + """ + self._paramMap[self.vocabSize] = value + return self + + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + def _create_model(self, java_model): + return CountVectorizerModel(java_model) + + +class CountVectorizerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by CountVectorizer. + """ + + @property + def vocabulary(self): + """ + An array of terms in the vocabulary. 
+ """ + return self._call_java("vocabulary") + + @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol): """ From aeef44a3e32b53f7adecc8e9cfd684fb4598e87d Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 21 Sep 2015 13:11:28 -0700 Subject: [PATCH 050/896] [SPARK-3147] [MLLIB] [STREAMING] Streaming 2-sample statistical significance testing Implementation of significance testing using Streaming API. Author: Feynman Liang Author: Feynman Liang Closes #4716 from feynmanliang/ab_testing. --- .../examples/mllib/StreamingTestExample.scala | 90 +++++++ .../spark/mllib/stat/test/StreamingTest.scala | 145 +++++++++++ .../mllib/stat/test/StreamingTestMethod.scala | 167 ++++++++++++ .../spark/mllib/stat/test/TestResult.scala | 22 ++ .../spark/mllib/stat/StreamingTestSuite.scala | 243 ++++++++++++++++++ 5 files changed, 667 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala new file mode 100644 index 000000000000..ab29f90254d3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.mllib.stat.test.StreamingTest +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.util.Utils + +/** + * Perform streaming testing using Welch's 2-sample t-test on a stream of data, where the data + * stream arrives as text files in a directory. Stops when the two groups are statistically + * significant (p-value < 0.05) or after a user-specified timeout in number of batches is exceeded. + * + * The rows of the text files must be in the form `Boolean, Double`. 
For example: + * false, -3.92 + * true, 99.32 + * + * Usage: + * StreamingTestExample + * + * To run on your local machine using the directory `dataDir` with 5 seconds between each batch and + * a timeout after 100 insignificant batches, call: + * $ bin/run-example mllib.StreamingTestExample dataDir 5 100 + * + * As you add text files to `dataDir` the significance test wil continually update every + * `batchDuration` seconds until the test becomes significant (p-value < 0.05) or the number of + * batches processed exceeds `numBatchesTimeout`. + */ +object StreamingTestExample { + + def main(args: Array[String]) { + if (args.length != 3) { + // scalastyle:off println + System.err.println( + "Usage: StreamingTestExample " + + " ") + // scalastyle:on println + System.exit(1) + } + val dataDir = args(0) + val batchDuration = Seconds(args(1).toLong) + val numBatchesTimeout = args(2).toInt + + val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample") + val ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint({ + val dir = Utils.createTempDir() + dir.toString + }) + + val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { + case Array(label, value) => (label.toBoolean, value.toDouble) + }) + + val streamingTest = new StreamingTest() + .setPeacePeriod(0) + .setWindowSize(0) + .setTestMethod("welch") + + val out = streamingTest.registerStream(data) + out.print() + + // Stop processing if test becomes significant or we time out + var timeoutCounter = numBatchesTimeout + out.foreachRDD { rdd => + timeoutCounter -= 1 + val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _) + if (timeoutCounter == 0 || anySignificant) rdd.context.stop() + } + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala new file mode 100644 index 000000000000..75c6a51d0957 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * :: Experimental :: + * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The + * Boolean identifies which sample each observation comes from, and the Double is the numeric value + * of the observation. 
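The default method ("welch", implemented further down in this patch) compares the two groups once each has been reduced to summary statistics. A standalone sketch of the textbook Welch statistic and Welch-Satterthwaite degrees of freedom follows; note that the patch itself delegates the p-value computation to commons-math's TTest, and every name in the sketch is illustrative:

{{{
// Standalone sketch (textbook formulas, not the patch's code path): Welch's t statistic
// and the Welch-Satterthwaite degrees of freedom for two summarized samples.
object WelchSketch {
  final case class Summary(mean: Double, variance: Double, n: Long)

  def welch(a: Summary, b: Summary): (Double, Double) = {
    val va = a.variance / a.n   // s_a^2 / n_a
    val vb = b.variance / b.n
    val t = (a.mean - b.mean) / math.sqrt(va + vb)
    val df = math.pow(va + vb, 2) /
      (math.pow(va, 2) / (a.n - 1) + math.pow(vb, 2) / (b.n - 1))
    (t, df)
  }

  def main(args: Array[String]): Unit = {
    val control = Summary(mean = 0.1, variance = 1.0, n = 1000L)
    val experiment = Summary(mean = 0.4, variance = 1.1, n = 1000L)
    val (t, df) = welch(control, experiment)
    println(f"t = $t%.3f, df = $df%.1f")
  }
}
}}}

With these example numbers the group means differ by several standard errors, so the corresponding p-value would fall well below 0.05.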
+ * + * To address novelty affects, the `peacePeriod` specifies a set number of initial + * [[org.apache.spark.rdd.RDD]] batches of the [[DStream]] to be dropped from significance testing. + * + * The `windowSize` sets the number of batches each significance test is to be performed over. The + * window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform + * cumulative processing, using all batches seen so far. + * + * Different tests may be used for assessing statistical significance depending on assumptions + * satisfied by data. For more details, see [[StreamingTestMethod]]. The `testMethod` specifies + * which test will be used. + * + * Use a builder pattern to construct a streaming test in an application, for example: + * {{{ + * val model = new StreamingTest() + * .setPeacePeriod(10) + * .setWindowSize(0) + * .setTestMethod("welch") + * .registerStream(DStream) + * }}} + */ +@Experimental +@Since("1.6.0") +class StreamingTest @Since("1.6.0") () extends Logging with Serializable { + private var peacePeriod: Int = 0 + private var windowSize: Int = 0 + private var testMethod: StreamingTestMethod = WelchTTest + + /** Set the number of initial batches to ignore. Default: 0. */ + @Since("1.6.0") + def setPeacePeriod(peacePeriod: Int): this.type = { + this.peacePeriod = peacePeriod + this + } + + /** + * Set the number of batches to compute significance tests over. Default: 0. + * A value of 0 will use all batches seen so far. + */ + @Since("1.6.0") + def setWindowSize(windowSize: Int): this.type = { + this.windowSize = windowSize + this + } + + /** Set the statistical method used for significance testing. Default: "welch" */ + @Since("1.6.0") + def setTestMethod(method: String): this.type = { + this.testMethod = StreamingTestMethod.getTestMethodFromName(method) + this + } + + /** + * Register a [[DStream]] of values for significance testing. + * + * @param data stream of (key,value) pairs where the key denotes group membership (true = + * experiment, false = control) and the value is the numerical metric to test for + * significance + * @return stream of significance testing results + */ + @Since("1.6.0") + def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = { + val dataAfterPeacePeriod = dropPeacePeriod(data) + val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) + val pairedSummaries = pairSummaries(summarizedData) + + testMethod.doTest(pairedSummaries) + } + + /** Drop all batches inside the peace period. */ + private[stat] def dropPeacePeriod( + data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = { + data.transform { (rdd, time) => + if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { + rdd + } else { + data.context.sparkContext.parallelize(Seq()) + } + } + } + + /** Compute summary statistics over each key and the specified test window size. 
*/ + private[stat] def summarizeByKeyAndWindow( + data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = { + if (this.windowSize == 0) { + data.updateStateByKey[StatCounter]( + (newValues: Seq[Double], oldSummary: Option[StatCounter]) => { + val newSummary = oldSummary.getOrElse(new StatCounter()) + newSummary.merge(newValues) + Some(newSummary) + }) + } else { + val windowDuration = data.slideDuration * this.windowSize + data + .groupByKeyAndWindow(windowDuration) + .mapValues { values => + val summary = new StatCounter() + values.foreach(value => summary.merge(value)) + summary + } + } + } + + /** + * Transform a stream of summaries into pairs representing summary statistics for control group + * and experiment group up to this batch. + */ + private[stat] def pairSummaries(summarizedData: DStream[(Boolean, StatCounter)]) + : DStream[(StatCounter, StatCounter)] = { + summarizedData + .map[(Int, StatCounter)](x => (0, x._2)) + .groupByKey() // should be length two (control/experiment group) + .map(x => (x._2.head, x._2.last)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala new file mode 100644 index 000000000000..a7eaed51b4d5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import java.io.Serializable + +import scala.language.implicitConversions +import scala.math.pow + +import com.twitter.chill.MeatLocker +import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues +import org.apache.commons.math3.stat.inference.TTest + +import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * Significance testing methods for [[StreamingTest]]. New 2-sample statistical significance tests + * should extend [[StreamingTestMethod]] and introduce a new entry in + * [[StreamingTestMethod.TEST_NAME_TO_OBJECT]] + */ +private[stat] sealed trait StreamingTestMethod extends Serializable { + + val methodName: String + val nullHypothesis: String + + protected type SummaryPairStream = + DStream[(StatCounter, StatCounter)] + + /** + * Perform streaming 2-sample statistical significance testing. + * + * @param sampleSummaries stream pairs of summary statistics for the 2 samples + * @return stream of rest results + */ + def doTest(sampleSummaries: SummaryPairStream): DStream[StreamingTestResult] + + /** + * Implicit adapter to convert between streaming summary statistics type and the type required by + * the t-testing libraries. 
+ */ + protected implicit def toApacheCommonsStats( + summaryStats: StatCounter): StatisticalSummaryValues = { + new StatisticalSummaryValues( + summaryStats.mean, + summaryStats.variance, + summaryStats.count, + summaryStats.max, + summaryStats.min, + summaryStats.mean * summaryStats.count + ) + } +} + +/** + * Performs Welch's 2-sample t-test. The null hypothesis is that the two data sets have equal mean. + * This test does not assume equal variance between the two samples and does not assume equal + * sample size. + * + * @see http://en.wikipedia.org/wiki/Welch%27s_t_test + */ +private[stat] object WelchTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Welch's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def welchDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = { + val s1 = sample1.getVariance + val n1 = sample1.getN + val s2 = sample2.getVariance + val n2 = sample2.getN + + val a = pow(s1, 2) / n1 + val b = pow(s2, 2) / n2 + + pow(a + b, 2) / ((pow(a, 2) / (n1 - 1)) + (pow(b, 2) / (n2 - 1))) + } + + new StreamingTestResult( + tTester.get.tTest(statsA, statsB), + welchDF(statsA, statsB), + tTester.get.t(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Performs Students's 2-sample t-test. The null hypothesis is that the two data sets have equal + * mean. This test assumes equal variance between the two samples and does not assume equal sample + * size. For unequal variances, Welch's t-test should be used instead. + * + * @see http://en.wikipedia.org/wiki/Student%27s_t-test + */ +private[stat] object StudentTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Student's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def studentDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = + sample1.getN + sample2.getN - 2 + + new StreamingTestResult( + tTester.get.homoscedasticTTest(statsA, statsB), + studentDF(statsA, statsB), + tTester.get.homoscedasticT(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Companion object holding supported [[StreamingTestMethod]] names and handles conversion between + * strings used in [[StreamingTest]] configuration and actual method implementation. + * + * Currently supported tests: `welch`, `student`. + */ +private[stat] object StreamingTestMethod { + // Note: after new `StreamingTestMethod`s are implemented, please update this map. + private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map( + "welch"->WelchTTest, + "student"->StudentTTest) + + def getTestMethodFromName(method: String): StreamingTestMethod = + TEST_NAME_TO_OBJECT.get(method) match { + case Some(test) => test + case None => + throw new IllegalArgumentException( + "Unrecognized method name. 
Supported streaming test methods: " + + TEST_NAME_TO_OBJECT.keys.mkString(", ")) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index d01b3707be94..b0916d3e8465 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -115,3 +115,25 @@ class KolmogorovSmirnovTestResult private[stat] ( "Kolmogorov-Smirnov test summary:\n" + super.toString } } + +/** + * :: Experimental :: + * Object containing the test results for streaming testing. + */ +@Experimental +@Since("1.6.0") +private[stat] class StreamingTestResult @Since("1.6.0") ( + @Since("1.6.0") override val pValue: Double, + @Since("1.6.0") override val degreesOfFreedom: Double, + @Since("1.6.0") override val statistic: Double, + @Since("1.6.0") val method: String, + @Since("1.6.0") override val nullHypothesis: String) + extends TestResult[Double] with Serializable { + + override def toString: String = { + "Streaming test summary:\n" + + s"method: $method\n" + + super.toString + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala new file mode 100644 index 000000000000..d3e9ef4ff079 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest} +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter +import org.apache.spark.util.random.XORShiftRandom + +class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { + + override def maxWaitTimeMillis : Int = 30000 + + test("accuracy for null hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for alternative hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for null hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == StudentTTest.methodName)) + } + + test("accuracy for alternative hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + 
assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == StudentTTest.methodName)) + } + + test("batches within same test window are grouped") { + // set parameters + val testWindow = 3 + val numBatches = 5 + val pointsPerBatch = 100 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(testWindow) + .setPeacePeriod(0) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, + (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) + val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches) + val outputCounts = outputBatches.flatten.map(_._2.count) + + // number of batches seen so far does not exceed testWindow, expect counts to continue growing + for (i <- 0 until testWindow) { + assert(outputCounts.drop(2 * i).take(2).forall(_ == (i + 1) * pointsPerBatch / 2)) + } + + // number of batches seen exceeds testWindow, expect counts to be constant + assert(outputCounts.drop(2 * (testWindow - 1)).forall(_ == testWindow * pointsPerBatch / 2)) + } + + + test("entries in peace period are dropped") { + // set parameters + val peacePeriod = 3 + val numBatches = 7 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(peacePeriod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream)) + val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch) + } + + test("null hypothesis when only data from one group is present") { + // set parameters + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + + val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + .map(batch => batch.filter(_._1)) // only keep one test group + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) + } + + // Generate testing input with half of the entries in group A and half in group B + private def generateTestData( + numBatches: Int, + pointsPerBatch: Int, + meanA: Double, + stdevA: Double, + meanB: Double, + stdevB: Double, + seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = { + val rand = new XORShiftRandom(seed) + val numTrues = pointsPerBatch / 2 + val data = (0 until numBatches).map { i => + (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++ + (pointsPerBatch / 2 until pointsPerBatch).map { idx => + (false, meanB + stdevB * rand.nextGaussian()) + } + } + + data + } +} From 97a99dde6e8d69a4c4c135dc1d9b1520b2548b5b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 21 Sep 2015 13:15:44 -0700 Subject: [PATCH 051/896] [SPARK-10676] [DOCS] Add documentation for SASL encryption options. 
Author: Marcelo Vanzin Closes #8803 from vanzin/SPARK-10676. --- docs/configuration.md | 16 ++++++++++++++++ docs/security.md | 22 ++++++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index b22587c70316..284f97ad09ec 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1285,6 +1285,22 @@ Apart from these, the following properties are also available, and may be useful not running on YARN and authentication is enabled. + + spark.authenticate.enableSaslEncryption + false + + Enable encrypted communication when authentication is enabled. This option is currently + only supported by the block transfer service. + + + + spark.network.sasl.serverAlwaysEncrypt + false + + Disable unencrypted connections for services that support SASL authentication. This is + currently supported by the external shuffle service. + + spark.core.connection.ack.wait.timeout 60s diff --git a/docs/security.md b/docs/security.md index d4ffa60e59a3..177109415180 100644 --- a/docs/security.md +++ b/docs/security.md @@ -23,9 +23,16 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. However SSL is not supported yet for WebUI and block transfer service. +Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. SASL encryption is +supported for the block transfer service. Encryption is not yet supported for the WebUI. -Connection encryption (SSL) configuration is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). +Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle +files, cached data, and other application files. If encrypting this data is desired, a workaround is +to configure your cluster manager to store application data on encrypted disks. + +### SSL Configuration + +Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. 
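As an illustration of how an application would opt in to the behaviour documented here, the two application-side properties from the table added to configuration.md could be supplied through SparkConf. The property names come from this patch; setting them programmatically, and the app name, are assumptions:

{{{
// Illustrative only: application-side settings for authentication plus SASL encryption
// of the block transfer service. Only the property names are taken from the patch.
import org.apache.spark.SparkConf

object SaslEncryptionConfSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("sasl-encryption-example")
      .set("spark.authenticate", "true")                      // enable internal authentication
      .set("spark.authenticate.enableSaslEncryption", "true") // encrypt block transfers
    conf.getAll.foreach { case (k, v) => println(s"$k=$v") }
  }
}
}}}

The spark.network.sasl.serverAlwaysEncrypt option, by contrast, belongs in the external shuffle service's configuration rather than the application's, as the new security.md section below spells out.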
@@ -47,6 +54,17 @@ follows: * Import all exported public keys into a single trust-store * Distribute the trust-store over the nodes +### Configuring SASL Encryption + +SASL encryption is currently supported for the block transfer service when authentication +(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set +`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration. + +When using an external shuffle service, it's possible to disable unencrypted connections by setting +`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that +option is enabled, applications that are not set up to use SASL encryption will fail to connect to +the shuffle service. + ## Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight From 362539f8d97f6bb67f0d0983f7dea36b77cc9d18 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 13:33:10 -0700 Subject: [PATCH 052/896] [SPARK-10630] [SQL] Add a createDataFrame API that takes in a java list It would be nice to support creating a DataFrame directly from a Java List of Row. Author: Holden Karau Closes #8779 from holdenk/SPARK-10630-create-DataFrame-from-Java-List. --- .../scala/org/apache/spark/sql/SQLContext.scala | 14 ++++++++++++++ .../org/apache/spark/sql/JavaDataFrameSuite.java | 10 ++++++++++ 2 files changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f099940800cc..1bd4e26fb316 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -476,6 +476,20 @@ class SQLContext(@transient val sparkContext: SparkContext) createDataFrame(rowRDD.rdd, schema) } + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[java.util.List]] containing [[Row]]s using the given schema. + * It is important to make sure that the structure of every [[Row]] of the provided List matches + * the provided schema. Otherwise, there will be runtime exception. + * + * @group dataframes + * @since 1.6.0 + */ + @DeveloperApi + def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { + DataFrame(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + } + /** * Applies a schema to an RDD of Java Beans. 
* diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 5f9abd4999ce..250ac2e1092d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -37,6 +37,7 @@ import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -181,6 +182,15 @@ public void testCreateDataFrameFromJavaBeans() { } } + @Test + public void testCreateDataFromFromList() { + StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); + List rows = Arrays.asList(RowFactory.create(0)); + DataFrame df = context.createDataFrame(rows, schema); + Row[] result = df.collect(); + Assert.assertEquals(1, result.length); + } + private static final Comparator crosstabRowComparator = new Comparator() { @Override public int compare(Row row1, Row row2) { From 7c4f852bfc39537840f56cd8121457a0dc1ad7c1 Mon Sep 17 00:00:00 2001 From: noelsmith Date: Mon, 21 Sep 2015 14:24:19 -0700 Subject: [PATCH 053/896] [DOC] [PYSPARK] [MLLIB] Added newlines to docstrings to fix parameter formatting Added newlines before `:param ...:` and `:return:` markup. Without these, parameter lists aren't formatted correctly in the API docs. I.e: ![screen shot 2015-09-21 at 21 49 26](https://cloud.githubusercontent.com/assets/11915197/10004686/de3c41d4-60aa-11e5-9c50-a46dcb51243f.png) .. looks like this once newline is added: ![screen shot 2015-09-21 at 21 50 14](https://cloud.githubusercontent.com/assets/11915197/10004706/f86bfb08-60aa-11e5-8524-ae4436713502.png) Author: noelsmith Closes #8851 from noel-smith/docstring-missing-newline-fix. --- python/pyspark/ml/param/__init__.py | 4 ++++ python/pyspark/ml/pipeline.py | 1 + python/pyspark/ml/tuning.py | 2 ++ python/pyspark/ml/wrapper.py | 2 ++ python/pyspark/mllib/evaluation.py | 2 +- python/pyspark/mllib/linalg/__init__.py | 1 + python/pyspark/streaming/context.py | 2 ++ python/pyspark/streaming/mqtt.py | 1 + 8 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index eeeac49b2198..2e0c63cb47b1 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -164,6 +164,7 @@ def extractParamMap(self, extra=None): a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra. + :param extra: extra param values :return: merged param map """ @@ -182,6 +183,7 @@ def copy(self, extra=None): embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ @@ -201,6 +203,7 @@ def _shouldOwn(self, param): def _resolveParam(self, param): """ Resolves a param and validates the ownership. + :param param: param name or the param instance, which must belong to this Params instance :return: resolved param instance @@ -243,6 +246,7 @@ def _copyValues(self, to, extra=None): """ Copies param values from this instance to another instance for params shared by them. 
+ :param to: the target instance :param extra: extra params to be copied :return: the target instance with param values copied diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 13cf2b0f7bbd..312a8502b3a2 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -154,6 +154,7 @@ def __init__(self, stages=None): def setStages(self, value): """ Set pipeline stages. + :param value: a list of transformers or estimators :return: the pipeline instance """ diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index ab5621f45c72..705ee5368575 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -254,6 +254,7 @@ def copy(self, extra=None): Creates a copy of this instance with a randomly generated uid and some extra params. This copies creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ @@ -290,6 +291,7 @@ def copy(self, extra=None): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 8218c7c5f801..4bcb4aaec89d 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -119,6 +119,7 @@ def _create_model(self, java_model): def _fit_java(self, dataset): """ Fits a Java model to the input dataset. + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` :param params: additional params (overwriting embedded values) @@ -173,6 +174,7 @@ def copy(self, extra=None): extra params. This implementation first calls Params.copy and then make a copy of the companion Java model with extra params. So both the Python wrapper and the Java model get copied. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 4398ca86f2ec..a90e5c50e54b 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -147,7 +147,7 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. - :param predictionAndLabels an RDD of (prediction, label) pairs. + :param predictionAndLabels: an RDD of (prediction, label) pairs. >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index f929e3e96fbe..ea42127f1651 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -240,6 +240,7 @@ class Vector(object): def toArray(self): """ Convert the vector into an numpy.ndarray + :return: numpy.ndarray """ raise NotImplementedError diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 4069d7a14998..a8c9ffc235b9 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -240,6 +240,7 @@ def start(self): def awaitTermination(self, timeout=None): """ Wait for the execution to stop. 
+ @param timeout: time to wait in seconds """ if timeout is None: @@ -252,6 +253,7 @@ def awaitTerminationOrTimeout(self, timeout): Wait for the execution to stop. Return `true` if it's stopped; or throw the reported error during the execution; or `false` if the waiting time elapsed before returning from the method. + @param timeout: time to wait in seconds """ self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py index f06598971c54..fa83006c36db 100644 --- a/python/pyspark/streaming/mqtt.py +++ b/python/pyspark/streaming/mqtt.py @@ -31,6 +31,7 @@ def createStream(ssc, brokerUrl, topic, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): """ Create an input stream that pulls messages from a Mqtt Broker. + :param ssc: StreamingContext object :param brokerUrl: Url of remote mqtt publisher :param topic: topic name to subscribe to From 72869883f12b6e0a4e5aad79c0ac2cfdb4d83f09 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Sep 2015 16:47:52 -0700 Subject: [PATCH 054/896] [SPARK-10649] [STREAMING] Prevent inheriting job group and irrelevant job description in streaming jobs The job group, and job descriptions information is passed through thread local properties, and get inherited by child threads. In case of spark streaming, the streaming jobs inherit these properties from the thread that called streamingContext.start(). This may not make sense. 1. Job group: This is mainly used for cancelling a group of jobs together. It does not make sense to cancel streaming jobs like this, as the effect will be unpredictable. And its not a valid usecase any way, to cancel a streaming context, call streamingContext.stop() 2. Job description: This is used to pass on nice text descriptions for jobs to show up in the UI. The job description of the thread that calls streamingContext.start() is not useful for all the streaming jobs, as it does not make sense for all of the streaming jobs to have the same description, and the description may or may not be related to streaming. The solution in this PR is meant for the Spark master branch, where local properties are inherited by cloning the properties. The job group and job description in the thread that starts the streaming scheduler are explicitly removed, so that all the subsequent child threads does not inherit them. Also, the starting is done in a new child thread, so that setting the job group and description for streaming, does not change those properties in the thread that called streamingContext.start(). Author: Tathagata Das Closes #8781 from tdas/SPARK-10649. 
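The change below hinges on the fact that Spark keeps call sites, job groups, and job descriptions in inheritable thread-local properties, so a thread created afterwards starts out with the caller's values in its own slot. A minimal, self-contained sketch of that mechanism and of the fix, using plain JVM threads rather than Spark APIs (the property key "spark.jobGroup.id" is borrowed from the test below purely for illustration):

    // Local properties live in an InheritableThreadLocal, so a child thread picks up
    // whatever the caller had set at the moment the thread is created.
    object InheritedPropsSketch {
      private val props = new InheritableThreadLocal[Map[String, String]] {
        override protected def initialValue(): Map[String, String] = Map.empty
      }

      def main(args: Array[String]): Unit = {
        props.set(Map("spark.jobGroup.id" -> "non-streaming"))

        // Any thread started now inherits the caller's properties; this is how
        // streaming jobs ended up with the caller's job group and description.
        val leaked = new Thread("leaked") {
          override def run(): Unit = println(s"inherited: ${props.get()}")
        }
        leaked.start(); leaked.join()

        // The fix: do the sensitive startup in a fresh thread and clear the unwanted
        // keys there (like sc.clearJobGroup()); each thread works on its own binding,
        // so the caller's properties are left untouched.
        val isolated = new Thread("streaming-start") {
          override def run(): Unit = {
            props.set(props.get() - "spark.jobGroup.id")
            println(s"cleared inside the child: ${props.get()}")
          }
        }
        isolated.start(); isolated.join()

        println(s"caller still sees: ${props.get()}")
      }
    }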
--- .../org/apache/spark/util/ThreadUtils.scala | 59 +++++++++++++++++++ .../apache/spark/util/ThreadUtilsSuite.scala | 24 +++++++- .../spark/streaming/StreamingContext.scala | 15 ++++- .../streaming/StreamingContextSuite.scala | 32 ++++++++++ 4 files changed, 126 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index ca5624a3d8b3..22e291a2b48d 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -21,6 +21,7 @@ package org.apache.spark.util import java.util.concurrent._ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.util.control.NonFatal import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} @@ -86,4 +87,62 @@ private[spark] object ThreadUtils { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() Executors.newSingleThreadScheduledExecutor(threadFactory) } + + /** + * Run a piece of code in a new thread and return the result. Exception in the new thread is + * thrown in the caller thread with an adjusted stack trace that removes references to this + * method for clarity. The exception stack traces will be like the following + * + * SomeException: exception-message + * at CallerClass.body-method (sourcefile.scala) + * at ... run in separate thread using org.apache.spark.util.ThreadUtils ... () + * at CallerClass.caller-method (sourcefile.scala) + * ... + */ + def runInNewThread[T]( + threadName: String, + isDaemon: Boolean = true)(body: => T): T = { + @volatile var exception: Option[Throwable] = None + @volatile var result: T = null.asInstanceOf[T] + + val thread = new Thread(threadName) { + override def run(): Unit = { + try { + result = body + } catch { + case NonFatal(e) => + exception = Some(e) + } + } + } + thread.setDaemon(isDaemon) + thread.start() + thread.join() + + exception match { + case Some(realException) => + // Remove the part of the stack that shows method calls into this helper method + // This means drop everything from the top until the stack element + // ThreadUtils.runInNewThread(), and then drop that as well (hence the `drop(1)`). + val baseStackTrace = Thread.currentThread().getStackTrace().dropWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)).drop(1) + + // Remove the part of the new thread stack that shows methods call from this helper method + val extraStackTrace = realException.getStackTrace.takeWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)) + + // Combine the two stack traces, with a place holder just specifying that there + // was a helper method used, without any further details of the helper + val placeHolderStackElem = new StackTraceElement( + s"... 
run in separate thread using ${ThreadUtils.getClass.getName.stripSuffix("$")} ..", + " ", "", -1) + val finalStackTrace = extraStackTrace ++ Seq(placeHolderStackElem) ++ baseStackTrace + + // Update the stack trace and rethrow the exception in the caller thread + realException.setStackTrace(finalStackTrace) + throw realException + case None => + result + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 8c51e6b14b7f..620e4debf4e0 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.concurrent.{Await, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} +import scala.util.Random import org.apache.spark.SparkFunSuite @@ -66,4 +67,25 @@ class ThreadUtilsSuite extends SparkFunSuite { val futureThreadName = Await.result(f, 10.seconds) assert(futureThreadName === callerThreadName) } + + test("runInNewThread") { + import ThreadUtils._ + assert(runInNewThread("thread-name") { Thread.currentThread().getName } === "thread-name") + assert(runInNewThread("thread-name") { Thread.currentThread().isDaemon } === true) + assert( + runInNewThread("thread-name", isDaemon = false) { Thread.currentThread().isDaemon } === false + ) + val uniqueExceptionMessage = "test" + Random.nextInt() + val exception = intercept[IllegalArgumentException] { + runInNewThread("thread-name") { throw new IllegalArgumentException(uniqueExceptionMessage) } + } + assert(exception.asInstanceOf[IllegalArgumentException].getMessage === uniqueExceptionMessage) + assert(exception.getStackTrace.mkString("\n").contains( + "... run in separate thread using org.apache.spark.util.ThreadUtils ...") === true, + "stack trace does not contain expected place holder" + ) + assert(exception.getStackTrace.mkString("\n").contains("ThreadUtils.scala") === false, + "stack trace contains unexpected references to ThreadUtils" + ) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index b496d1f341a0..6720ba4f72cf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils} /** * Main entry point for Spark Streaming functionality. 
It provides methods used to create @@ -588,12 +588,20 @@ class StreamingContext private[streaming] ( state match { case INITIALIZED => startSite.set(DStream.getCreationSite()) - sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() try { validate() - scheduler.start() + + // Start the streaming scheduler in a new thread, so that thread local properties + // like call sites and job groups can be reset without affecting those of the + // current thread. + ThreadUtils.runInNewThread("streaming-start") { + sparkContext.setCallSite(startSite.get) + sparkContext.clearJobGroup() + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + scheduler.start() + } state = StreamingContextState.ACTIVE } catch { case NonFatal(e) => @@ -618,6 +626,7 @@ class StreamingContext private[streaming] ( } } + /** * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index d26894e88fc2..3b9d0d15ea04 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -180,6 +180,38 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.scheduler.isStarted === false) } + test("start should set job group and description of streaming jobs correctly") { + ssc = new StreamingContext(conf, batchDuration) + ssc.sc.setJobGroup("non-streaming", "non-streaming", true) + val sc = ssc.sc + + @volatile var jobGroupFound: String = "" + @volatile var jobDescFound: String = "" + @volatile var jobInterruptFound: String = "" + @volatile var allFound: Boolean = false + + addInputStream(ssc).foreachRDD { rdd => + jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) + jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + allFound = true + } + ssc.start() + + eventually(timeout(10 seconds), interval(10 milliseconds)) { + assert(allFound === true) + } + + // Verify streaming jobs have expected thread-local properties + assert(jobGroupFound === null) + assert(jobDescFound === null) + assert(jobInterruptFound === "false") + + // Verify current thread's thread-local properties have not changed + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true") + } test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) From 0494c80ef54f6f3a8c6f2d92abfe1a77a91df8b0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 21 Sep 2015 18:06:45 -0700 Subject: [PATCH 055/896] [SPARK-10495] [SQL] Read date values in JSON data stored by Spark 1.5.0. https://issues.apache.org/jira/browse/SPARK-10681 Author: Yin Huai Closes #8806 from yhuai/SPARK-10495. 
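The parser change in this patch has to accept two encodings of the same DateType value: Spark 1.5.0 wrote the raw day count as a string (note the "col12":"16436" entry in the test data below), while other writers produce "yyyy-mm-dd". A rough standalone sketch of that branch, using java.time instead of Spark's DateTimeUtils:

    import java.time.LocalDate

    object JsonDateCompatSketch {
      // Mirrors the branch added to JacksonParser below (a sketch, not the real code):
      // dates arrive either as "yyyy-MM-dd" or, for data written by Spark 1.5.0,
      // as the number of days since the epoch rendered as a string.
      def parseJsonDateAsDays(stringValue: String): Int =
        if (stringValue.contains("-")) {
          LocalDate.parse(stringValue).toEpochDay.toInt // "2015-01-01" -> 16436
        } else {
          stringValue.toInt                             // "16436"     -> 16436
        }

      def main(args: Array[String]): Unit = {
        // Both encodings of the same date map to the same internal day count.
        assert(parseJsonDateAsDays("2015-01-01") == parseJsonDateAsDays("16436"))
      }
    }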
--- .../datasources/json/JacksonGenerator.scala | 36 ++++++ .../datasources/json/JacksonParser.scala | 15 ++- .../datasources/json/JsonSuite.scala | 103 +++++++++++++++++- 3 files changed, 152 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index f65c7bbd6e29..23bada1ddd92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -73,6 +73,38 @@ private[sql] object JacksonGenerator { valWriter(field.dataType, v) } gen.writeEndObject() + + // For UDT, udt.serialize will produce SQL types. So, we need the following three cases. + case (ArrayType(ty, _), v: ArrayData) => + gen.writeStartArray() + v.foreach(ty, (_, value) => valWriter(ty, value)) + gen.writeEndArray() + + case (MapType(kt, vt, _), v: MapData) => + gen.writeStartObject() + v.foreach(kt, vt, { (k, v) => + gen.writeFieldName(k.toString) + valWriter(vt, v) + }) + gen.writeEndObject() + + case (StructType(ty), v: InternalRow) => + gen.writeStartObject() + var i = 0 + while (i < ty.length) { + val field = ty(i) + val value = v.get(i, field.dataType) + if (value != null) { + gen.writeFieldName(field.name) + valWriter(field.dataType, value) + } + i += 1 + } + gen.writeEndObject() + + case (dt, v) => + sys.error( + s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") } valWriter(rowSchema, row) @@ -133,6 +165,10 @@ private[sql] object JacksonGenerator { i += 1 } gen.writeEndObject() + + case (dt, v) => + sys.error( + s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") } valWriter(rowSchema, row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index ff4d8c04e8ea..c51140749c8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -62,10 +62,23 @@ private[sql] object JacksonParser { // guard the non string type null + case (VALUE_STRING, BinaryType) => + parser.getBinaryValue + case (VALUE_STRING, DateType) => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + val stringValue = parser.getText + if (stringValue.contains("-")) { + // The format of this string will probably be "yyyy-mm-dd". + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + } else { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + stringValue.toInt + } case (VALUE_STRING, TimestampType) => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. 
DateTimeUtils.stringToTime(parser.getText).getTime * 1000L case (VALUE_NUMBER_INT, TimestampType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6a18cc6d2713..b614e6c4148f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -24,7 +24,7 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType @@ -1159,4 +1159,105 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) }) } + + test("backward compatibility") { + // This test we make sure our JSON support can read JSON data generated by previous version + // of Spark generated through toJSON method and JSON data source. + // The data is generated by the following program. + // Here are a few notes: + // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13) + // in the JSON object. + // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to + // JSON objects generated by those Spark versions (col17). + // - If the type is NullType, we do not write data out. + + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) + + val constantValues = + Seq( + "a string in binary".getBytes("UTF-8"), + null, + true, + 1.toByte, + 2.toShort, + 3, + Long.MaxValue, + 0.25.toFloat, + 0.75, + new java.math.BigDecimal(s"1234.23456"), + new java.math.BigDecimal(s"1.23456"), + java.sql.Date.valueOf("2015-01-01"), + java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"), + Seq(2, 3, 4), + Map("a string" -> 2000L), + Row(4.75.toFloat, Seq(false, true)), + new MyDenseVector(Array(0.25, 2.25, 4.25))) + val data = + Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil + + // Data generated by previous versions. 
+ // scalastyle:off + val existingJSONData = + """{"col0":"Spark 1.2.2","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil + // scalastyle:on + + // Generate data for the current version. + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data, 1), schema) + withTempPath { path => + df.write.format("json").mode("overwrite").save(path.getCanonicalPath) + + // df.toJSON will convert internal rows to external rows first and then generate + // JSON objects. While, df.write.format("json") will write internal rows directly. + val allJSON = + existingJSONData ++ + df.toJSON.collect() ++ + sparkContext.textFile(path.getCanonicalPath).collect() + + Utils.deleteRecursively(path) + sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath) + + // Read data back with the schema specified. 
+ val col0Values = + Seq( + "Spark 1.2.2", + "Spark 1.3.1", + "Spark 1.3.1", + "Spark 1.4.1", + "Spark 1.4.1", + "Spark 1.5.0", + "Spark 1.5.0", + "Spark " + sqlContext.sparkContext.version, + "Spark " + sqlContext.sparkContext.version) + val expectedResult = col0Values.map { v => + Row.fromSeq(Seq(v) ++ constantValues) + } + checkAnswer( + sqlContext.read.format("json").schema(schema).load(path.getCanonicalPath), + expectedResult + ) + } + } } From c986e933a900602af47966bd41edb2116c421a39 Mon Sep 17 00:00:00 2001 From: Hossein Date: Mon, 21 Sep 2015 21:09:59 -0700 Subject: [PATCH 056/896] [SPARK-10711] [SPARKR] Do not assume spark.submit.deployMode is always set In ```RUtils.sparkRPackagePath()``` we 1. Call ``` sys.props("spark.submit.deployMode")``` which returns null if ```spark.submit.deployMode``` is not suet 2. Call ``` sparkConf.get("spark.submit.deployMode")``` which throws ```NoSuchElementException``` if ```spark.submit.deployMode``` is not set. This patch simply passes a default value ("cluster") for ```spark.submit.deployMode```. cc rxin Author: Hossein Closes #8832 from falaki/SPARK-10711. --- core/src/main/scala/org/apache/spark/api/r/RUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 9e807cc52f18..fd5646b5b637 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -44,7 +44,7 @@ private[spark] object RUtils { (sys.props("spark.master"), sys.props("spark.submit.deployMode")) } else { val sparkConf = SparkEnv.get.conf - (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode", "client")) } val isYarnCluster = master != null && master.contains("yarn") && deployMode == "cluster" From 1cd67415728e660a90e4dbe136272b5d6b8f1142 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 23:21:24 -0700 Subject: [PATCH 057/896] [SPARK-9821] [PYSPARK] pyspark-reduceByKey-should-take-a-custom-partitioner from the issue: In Scala, I can supply a custom partitioner to reduceByKey (and other aggregation/repartitioning methods like aggregateByKey and combinedByKey), but as far as I can tell from the Pyspark API, there's no way to do the same in Python. Here's an example of my code in Scala: weblogs.map(s => (getFileType(s), 1)).reduceByKey(new FileTypePartitioner(),_+_) But I can't figure out how to do the same in Python. The closest I can get is to call repartition before reduceByKey like so: weblogs.map(lambda s: (getFileType(s), 1)).partitionBy(3,hash_filetype).reduceByKey(lambda v1,v2: v1+v2).collect() But that defeats the purpose, because I'm shuffling twice instead of once, so my performance is worse instead of better. Author: Holden Karau Closes #8569 from holdenk/SPARK-9821-pyspark-reduceByKey-should-take-a-custom-partitioner. 
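For reference, the Scala capability being mirrored here looks roughly like the sketch below; FileTypePartitioner is the hypothetical partitioner named in the description, written out only for illustration. The diff that follows adds the equivalent partitionFunc argument to the PySpark methods so the same single-shuffle pattern works from Python.

    import org.apache.spark.{Partitioner, SparkConf, SparkContext}

    // Hypothetical stand-in for the FileTypePartitioner mentioned above:
    // route each file-type key to a fixed partition.
    class FileTypePartitioner(parts: Int) extends Partitioner {
      override def numPartitions: Int = parts
      override def getPartition(key: Any): Int = {
        val h = key.asInstanceOf[String].hashCode % parts
        if (h < 0) h + parts else h
      }
    }

    object ReduceByKeyWithPartitionerSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
        val counts = sc.parallelize(Seq("a.log", "b.log", "c.gif"))
          .map(name => (name.split('.').last, 1))
          // Scala reduceByKey already accepts a Partitioner, so aggregation and
          // partitioning happen in a single shuffle instead of two.
          .reduceByKey(new FileTypePartitioner(3), _ + _)
        counts.collect().foreach(println)
        sc.stop()
      }
    }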
--- python/pyspark/rdd.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 73d7d9a5692a..56e892243c79 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -686,7 +686,7 @@ def cartesian(self, other): other._jrdd_deserializer) return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) - def groupBy(self, f, numPartitions=None): + def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash): """ Return an RDD of grouped items. @@ -695,7 +695,7 @@ def groupBy(self, f, numPartitions=None): >>> sorted([(x, sorted(y)) for (x, y) in result]) [(0, [2, 8]), (1, [1, 1, 3, 5])] """ - return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) + return self.map(lambda x: (f(x), x)).groupByKey(numPartitions, partitionFunc) @ignore_unicode_prefix def pipe(self, command, env=None, checkCode=False): @@ -1539,22 +1539,23 @@ def values(self): """ return self.map(lambda x: x[1]) - def reduceByKey(self, func, numPartitions=None): + def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative reduce function. This will also perform the merging locally on each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. - Output will be hash-partitioned with C{numPartitions} partitions, or + Output will be partitioned with C{numPartitions} partitions, or the default parallelism level if C{numPartitions} is not specified. + Default partitioner is hash-partition. >>> from operator import add >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(rdd.reduceByKey(add).collect()) [('a', 2), ('b', 1)] """ - return self.combineByKey(lambda x: x, func, func, numPartitions) + return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc) def reduceByKeyLocally(self, func): """ @@ -1739,7 +1740,7 @@ def add_shuffle_key(split, iterator): # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numPartitions=None): + numPartitions=None, partitionFunc=portable_hash): """ Generic function to combine the elements for each key using a custom set of aggregation functions. @@ -1777,7 +1778,7 @@ def combineLocally(iterator): return merger.items() locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) @@ -1786,7 +1787,8 @@ def _mergeCombiners(iterator): return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) - def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None, + partitionFunc=portable_hash): """ Aggregate the values of each key, using given combine functions and a neutral "zero value". 
This function can return a different result type, U, than the type @@ -1800,9 +1802,9 @@ def createZero(): return copy.deepcopy(zeroValue) return self.combineByKey( - lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) + lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions, partitionFunc) - def foldByKey(self, zeroValue, func, numPartitions=None): + def foldByKey(self, zeroValue, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative function "func" and a neutral "zeroValue" which may be added to the result an @@ -1817,13 +1819,14 @@ def foldByKey(self, zeroValue, func, numPartitions=None): def createZero(): return copy.deepcopy(zeroValue) - return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) + return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions, + partitionFunc) def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) # TODO: support variant with custom partitioner - def groupByKey(self, numPartitions=None): + def groupByKey(self, numPartitions=None, partitionFunc=portable_hash): """ Group the values for each key in the RDD into a single sequence. Hash-partitions the resulting RDD with numPartitions partitions. @@ -1859,7 +1862,7 @@ def combine(iterator): return merger.items() locally_combined = self.mapPartitions(combine, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def groupByKey(it): merger = ExternalGroupBy(agg, memory, serializer) From bf20d6c9f9e478a5de24b45bbafd4dd89666c4cf Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 21 Sep 2015 23:29:59 -0700 Subject: [PATCH 058/896] [SPARK-10716] [BUILD] spark-1.5.0-bin-hadoop2.6.tgz file doesn't uncompress on OS X due to hidden file Remove ._SUCCESS.crc hidden file that may cause problems in distribution tar archive, and is not used Author: Sean Owen Closes #8846 from srowen/SPARK-10716. --- .../test_support/sql/orc_partitioned/._SUCCESS.crc | Bin 8 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 python/test_support/sql/orc_partitioned/._SUCCESS.crc diff --git a/python/test_support/sql/orc_partitioned/._SUCCESS.crc b/python/test_support/sql/orc_partitioned/._SUCCESS.crc deleted file mode 100644 index 3b7b044936a890cd8d651d349a752d819d71d22c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8 PcmYc;N@ieSU}69O2$TUk From 0180b849dbaf191826231eda7dfaaf146a19602b Mon Sep 17 00:00:00 2001 From: Jian Feng Date: Mon, 21 Sep 2015 23:36:41 -0700 Subject: [PATCH 059/896] [SPARK-10577] [PYSPARK] DataFrame hint for broadcast join https://issues.apache.org/jira/browse/SPARK-10577 Author: Jian Feng Closes #8801 from Jianfeng-chs/master. 
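The Python broadcast() added below is a thin wrapper over the existing Scala hint in org.apache.spark.sql.functions. A rough Scala sketch of how that hint is used (the DataFrames and column names here are invented for illustration):

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.functions.broadcast

    object BroadcastJoinHintSketch {
      def demo(sqlContext: SQLContext): Unit = {
        import sqlContext.implicits._
        val large = sqlContext.sparkContext.parallelize(1 to 10000)
          .map(i => (i % 100, i)).toDF("key", "value")
        val small = sqlContext.sparkContext.parallelize(1 to 100)
          .map(i => (i, s"name-$i")).toDF("key", "name")

        // Wrapping the small side in broadcast() asks the planner for a
        // BroadcastHashJoin instead of shuffling both sides.
        val joined = large.join(broadcast(small), "key")
        joined.explain() // expect a BroadcastHashJoin in the physical plan
      }
    }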
--- python/pyspark/sql/functions.py | 9 +++++++++ python/pyspark/sql/tests.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 26b8662718a6..fa04f4cd83b6 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -29,6 +29,7 @@ from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.dataframe import DataFrame def _create_function(name, doc=""): @@ -189,6 +190,14 @@ def approxCountDistinct(col, rsd=None): return Column(jc) +@since(1.6) +def broadcast(df): + """Marks a DataFrame as small enough for use in broadcast joins.""" + + sc = SparkContext._active_spark_context + return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sql_ctx) + + @since(1.4) def coalesce(*cols): """Returns the first column that is not null. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3e680f1030a7..645133b2b2d8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1075,6 +1075,24 @@ def foo(): self.assertRaises(TypeError, foo) + # add test for SPARK-10577 (test broadcast join hint) + def test_functions_broadcast(self): + from pyspark.sql.functions import broadcast + + df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + + # equijoin - should be converted into broadcast join + plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan() + self.assertEqual(1, plan1.toString().count("BroadcastHashJoin")) + + # no join key -- should not be a broadcast join + plan2 = df1.join(broadcast(df2))._jdf.queryExecution().executedPlan() + self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) + + # planner should not crash without a join + broadcast(df1)._jdf.queryExecution().executedPlan() + class HiveContextSQLTests(ReusedPySparkTestCase): From 781b21ba2a873ed29394c8dbc74fc700e3e0d17e Mon Sep 17 00:00:00 2001 From: Ewan Leith Date: Mon, 21 Sep 2015 23:43:20 -0700 Subject: [PATCH 060/896] [SPARK-10419] [SQL] Adding SQLServer support for datetimeoffset types to JdbcDialects Reading from Microsoft SQL Server over jdbc fails when the table contains datetimeoffset types. This patch registers a SQLServer JDBC Dialect that maps datetimeoffset to a String, as Microsoft suggest. Author: Ewan Leith Closes #8575 from realitymine-coordinator/sqlserver. --- .../apache/spark/sql/jdbc/JdbcDialects.scala | 18 ++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 1 + 2 files changed, 19 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 68ebaaca6c53..c70fea1c3f50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -137,6 +137,8 @@ object JdbcDialects { registerDialect(MySQLDialect) registerDialect(PostgresDialect) registerDialect(DB2Dialect) + registerDialect(MsSqlServerDialect) + /** * Fetch the JdbcDialect class corresponding to a given database url. @@ -260,3 +262,19 @@ case object DB2Dialect extends JdbcDialect { case _ => None } } + +/** + * :: DeveloperApi :: + * Default Microsoft SQL Server dialect, mapping the datetimeoffset types to a String on read. 
+ */ +@DeveloperApi +case object MsSqlServerDialect extends JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (typeName.contains("datetimeoffset")) { + // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients + Some(StringType) + } else None + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5ab9381de4d6..c4b039a9c535 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -408,6 +408,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect) + assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") == MsSqlServerDialect) assert(JdbcDialects.get("test.invalid") == NoopDialect) } From 1fcefef06950e2f03477282368ca835bbf40ff24 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Sep 2015 23:46:00 -0700 Subject: [PATCH 061/896] [SPARK-10446][SQL] Support to specify join type when calling join with usingColumns JIRA: https://issues.apache.org/jira/browse/SPARK-10446 Currently the method `join(right: DataFrame, usingColumns: Seq[String])` only supports inner join. It is more convenient to have it support other join types. Author: Liang-Chi Hsieh Closes #8600 from viirya/usingcolumns_df. --- python/pyspark/sql/dataframe.py | 6 ++++- .../org/apache/spark/sql/DataFrame.scala | 22 ++++++++++++++++++- .../apache/spark/sql/DataFrameJoinSuite.scala | 13 +++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fb995fa3a76b..80f8d8a0eb37 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -567,7 +567,11 @@ def join(self, other, on=None, how=None): if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) elif isinstance(on[0], basestring): - jdf = self._jdf.join(other._jdf, self._jseq(on)) + if how is None: + jdf = self._jdf.join(other._jdf, self._jseq(on), "inner") + else: + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, self._jseq(on), how) else: assert isinstance(on[0], Column), "on should be Column or list of Column" if len(on) > 1: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8f737c202393..ba94d77b2e60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -484,6 +484,26 @@ class DataFrame private[sql]( * @since 1.4.0 */ def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { + join(right, usingColumns, "inner") + } + + /** + * Equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. 
+ * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * @group dfops + * @since 1.6.0 + */ + def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( @@ -502,7 +522,7 @@ class DataFrame private[sql]( Join( joined.left, joined.right, - joinType = Inner, + joinType = JoinType(joinType), condition) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e2716d7841d8..56ad71ea4f48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -42,6 +42,19 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - join using multiple columns and specifying join type") { + val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "str"), "left"), + Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "right"), + Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil) + } + test("join - join using self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") From f24316e6d928c263cbf3872edd97982059c3db22 Mon Sep 17 00:00:00 2001 From: Madhusudanan Kandasamy Date: Tue, 22 Sep 2015 00:03:48 -0700 Subject: [PATCH 062/896] [SPARK-10458] [SPARK CORE] Added isStopped() method in SparkContext Added isStopped() method in SparkContext Author: Madhusudanan Kandasamy Closes #8749 from kmadhugit/SPARK-10458. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ebd8e946ee7a..967fec9f42bc 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -265,6 +265,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** + * @return true if context is stopped or in the midst of stopping. 
+ */ + def isStopped: Boolean = stopped.get() // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus From fd61b004877ba4d51c95cd0e08f53bffdf106395 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 22 Sep 2015 00:05:30 -0700 Subject: [PATCH 063/896] [Minor] style fix for previous commit f24316e --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 967fec9f42bc..bf3aeb488d59 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -265,6 +265,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** * @return true if context is stopped or in the midst of stopping. */ From 4da32bc0e747fefe847bffe493785d4d16069c04 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 00:07:30 -0700 Subject: [PATCH 064/896] [SPARK-8567] [SQL] Increase the timeout of o.a.s.sql.hive.HiveSparkSubmitSuite to 5 minutes. https://issues.apache.org/jira/browse/SPARK-8567 Looks like "SPARK-8368: includes jars passed in through --jars" is pretty flaky now. Based on some history runs, the time spent on a successful run may be from 1.5 minutes to almost 3 minutes. Let's try to increase the timeout and see if we can fix this test. https://amplab.cs.berkeley.edu/jenkins/job/Spark-1.5-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.0,label=spark-test/385/testReport/junit/org.apache.spark.sql.hive/HiveSparkSubmitSuite/SPARK_8368__includes_jars_passed_in_through___jars/history/?start=25 Author: Yin Huai Closes #8850 from yhuai/SPARK-8567-anotherTry. --- .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 97df249bdb6d..5f1660b62d41 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -139,7 +139,7 @@ class HiveSparkSubmitSuite new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { - val exitCode = failAfter(180.seconds) { process.waitFor() } + val exitCode = failAfter(300.seconds) { process.waitFor() } if (exitCode != 0) { // include logs in output. Note that logging is async and may not have completed // at the time this exception is raised From f3b727c801408b1cd50e5d9463f2fe0fce654a16 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Sep 2015 00:09:29 -0700 Subject: [PATCH 065/896] [SQL] [MINOR] map -> foreach. DataFrame.explain should use foreach to print the explain content. Author: Reynold Xin Closes #8862 from rxin/map-foreach. 
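A small aside on why foreach is the right call there: in Scala, map builds and returns a new collection, which is wasteful when the function is run only for its side effect, and on a lazy collection the side effect may never run at all. For example, in the Scala REPL:

    val lines = Seq("== Physical Plan ==", "Scan ...")

    lines.map(println)          // prints, but also builds a Seq[Unit] that is thrown away
    lines.foreach(println)      // states the intent: evaluate only for the side effect

    lines.view.map(println)     // nothing is printed until the view is forced
    lines.view.foreach(println) // prints both lines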
--- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index ba94d77b2e60..a11140b71736 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -320,9 +320,8 @@ class DataFrame private[sql]( * @since 1.3.0 */ def explain(extended: Boolean): Unit = { - ExplainCommand( - queryExecution.logical, - extended = extended).queryExecution.executedPlan.executeCollect().map { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + explain.queryExecution.executedPlan.executeCollect().foreach { // scalastyle:off println r => println(r.getString(0)) // scalastyle:on println From 0bd0e5bed2176b119b3ada590993e153757ea09b Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Tue, 22 Sep 2015 00:14:27 -0700 Subject: [PATCH 066/896] =?UTF-8?q?[SPARK-10695]=20[DOCUMENTATION]=20[MESO?= =?UTF-8?q?S]=20Fixing=20incorrect=20value=20informati=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …on for spark.mesos.constraints parameter. Author: Akash Mishra Closes #8816 from SleepyThread/constraint-fix. --- docs/running-on-mesos.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 460a66f37dd6..ec5a44d79212 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -189,10 +189,10 @@ using `conf.set("spark.cores.max", "10")` (for example). You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. {% highlight scala %} -conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +conf.set("spark.mesos.constraints", "tachyon:true;us-east-1:false") {% endhighlight %} -For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. +For example, Let's say `spark.mesos.constraints` is set to `tachyon:true;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. # Mesos Docker Support From 7278f792a73bbcf8d68f38dc2d87cf722693c4cf Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Tue, 22 Sep 2015 11:03:21 +0100 Subject: [PATCH 067/896] [SPARK-10718] [BUILD] Update License on conf files and corresponding excludes file update Update License on conf files and corresponding excludes file update Author: Rekha Joshi Author: Joshi Closes #8842 from rekhajoshm/SPARK-10718. 
--- .rat-excludes | 12 ------------ conf/docker.properties.template | 17 +++++++++++++++++ conf/fairscheduler.xml.template | 18 ++++++++++++++++++ conf/log4j.properties.template | 17 +++++++++++++++++ conf/metrics.properties.template | 17 +++++++++++++++++ conf/slaves.template | 17 +++++++++++++++++ conf/spark-defaults.conf.template | 17 +++++++++++++++++ conf/spark-env.sh.template | 17 +++++++++++++++++ .../spark/log4j-defaults-repl.properties | 17 +++++++++++++++++ .../org/apache/spark/log4j-defaults.properties | 17 +++++++++++++++++ 10 files changed, 154 insertions(+), 12 deletions(-) diff --git a/.rat-excludes b/.rat-excludes index 9165872b9fb2..08fba6d351d6 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -15,20 +15,8 @@ TAGS RELEASE control docs -docker.properties.template -fairscheduler.xml.template -spark-defaults.conf.template -log4j.properties -log4j.properties.template -metrics.properties -metrics.properties.template slaves -slaves.template -spark-env.sh spark-env.cmd -spark-env.sh.template -log4j-defaults.properties -log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js diff --git a/conf/docker.properties.template b/conf/docker.properties.template index 26e3bfd9c5b9..55cb094b4af4 100644 --- a/conf/docker.properties.template +++ b/conf/docker.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + spark.mesos.executor.docker.image: spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro spark.mesos.executor.home: /opt/spark diff --git a/conf/fairscheduler.xml.template b/conf/fairscheduler.xml.template index acf59e2a3598..385b2e772d2c 100644 --- a/conf/fairscheduler.xml.template +++ b/conf/fairscheduler.xml.template @@ -1,4 +1,22 @@ + + + FAIR diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 74c5cea94403..f3046be54d7c 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 7f17bc7eea4f..d6962e0da2f3 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # syntax: [instance].sink|source.[name].[options]=[value] # This file configures Spark's internal metrics system. The metrics system is diff --git a/conf/slaves.template b/conf/slaves.template index da0a01343d20..be42a638230b 100644 --- a/conf/slaves.template +++ b/conf/slaves.template @@ -1,2 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # A Spark Worker will be started on each of the machines listed below. localhost \ No newline at end of file diff --git a/conf/spark-defaults.conf.template b/conf/spark-defaults.conf.template index a48dcc70e136..19cba6e71ed1 100644 --- a/conf/spark-defaults.conf.template +++ b/conf/spark-defaults.conf.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Default system properties included when running spark-submit. # This is useful for setting default environmental settings. 
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index c05fe381a36a..990ded420be7 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -1,5 +1,22 @@ #!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties index 689afea64f8d..c85abc35b93b 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=WARN, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 27006e45e932..d44cc85dcbd8 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender From 870b8a2edd44c9274c43ca0db4ef5b0998e16fd8 Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Tue, 22 Sep 2015 11:05:24 +0100 Subject: [PATCH 068/896] [SPARK-10706] [MLLIB] Add java wrapper for random vector rdd Add java wrapper for random vector rdd holdenk srowen Author: Meihua Wu Closes #8841 from rotationsymmetry/SPARK-10706. --- .../spark/mllib/random/RandomRDDs.scala | 42 +++++++++++++++++++ .../mllib/random/JavaRandomRDDsSuite.java | 17 ++++++++ 2 files changed, 59 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index f8ff26b5795b..41d7c4d355f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -855,6 +855,48 @@ object RandomRDDs { sc, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * Java-friendly version of [[RandomRDDs#randomVectorRDD]]. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions, seed).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default number of partitions and the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols).toJavaRDD() + } + /** * Returns `numPartitions` if it is positive, or `sc.defaultParallelism` otherwise. 
*/ diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index fce5f6712f46..5728df5aeebd 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -246,6 +246,23 @@ public void testArbitrary() { Assert.assertEquals(2, rdd.first().length()); } } + + @Test + @SuppressWarnings("unchecked") + public void testRandomVectorRDD() { + UniformGenerator generator = new UniformGenerator(); + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = randomJavaVectorRDD(sc, generator, m, n); + JavaRDD rdd2 = randomJavaVectorRDD(sc, generator, m, n, p); + JavaRDD rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } } // This is just a test generator, it always returns a string of 42 From f4a3c4e34ce93bcaf29c0a35573932880a8b792b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 22 Sep 2015 10:19:08 -0700 Subject: [PATCH 069/896] [SPARK-9962] [ML] Decision Tree training: prevNodeIdsForInstances.unpersist() at end of training NodeIdCache: prevNodeIdsForInstances.unpersist() needs to be called at end of training. Author: Holden Karau Closes #8541 from holdenk/SPARK-9962-decission-tree-training-prevNodeIdsForiNstances-unpersist-at-end-of-training. --- .../scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 8 ++++---- .../org/apache/spark/mllib/tree/impl/NodeIdCache.scala | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 488e8e4fb5dc..c5ad8df73fac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -164,10 +164,10 @@ private[spark] class NodeIdCache( } } } - } - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 8f9eb24b57b5..0abed5411143 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -166,6 +166,10 @@ private[spark] class NodeIdCache( fs.delete(new Path(old.getCheckpointFile.get), true) } } + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } From 7104ee0e5dc1290b8b845a0a8ddcdb1875cfd060 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 22 Sep 2015 11:00:33 -0700 Subject: [PATCH 070/896] [SPARK-10750] [ML] ML Param validate should print better error information Currently when you set illegal value for params of array type (such as IntArrayParam, DoubleArrayParam, StringArrayParam), it will throw IllegalArgumentException but with incomprehensible error information. 
Take ```VectorSlicer.setNames``` as an example: ```scala val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") // The value of setNames must contain distinct elements, so the next line will throw an exception. vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4", "f1")) ``` It will throw an IllegalArgumentException as: ``` vectorSlicer_b3b4d1a10f43 parameter names given invalid value [Ljava.lang.String;798256c5. java.lang.IllegalArgumentException: vectorSlicer_b3b4d1a10f43 parameter names given invalid value [Ljava.lang.String;798256c5. ``` We should distinguish array-type values from primitive-type values in Param.validate(value: T), so that we get better error information. ``` vectorSlicer_3b744ea277b2 parameter names given invalid value [f1,f4,f1]. java.lang.IllegalArgumentException: vectorSlicer_3b744ea277b2 parameter names given invalid value [f1,f4,f1]. ``` Author: Yanbo Liang Closes #8863 from yanboliang/spark-10750. --- .../src/main/scala/org/apache/spark/ml/param/params.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index de32b7218c27..48f6269e57e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -65,7 +65,12 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali */ private[param] def validate(value: T): Unit = { if (!isValid(value)) { - throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value.") + val valueToString = value match { + case v: Array[_] => v.mkString("[", ",", "]") + case _ => value.toString + } + throw new IllegalArgumentException( + s"$parent parameter $name given invalid value $valueToString.") } } From 2ea0f2e11b82ef4817c7e6a162ea23da7860b893 Mon Sep 17 00:00:00 2001 From: xutingjun Date: Tue, 22 Sep 2015 11:01:32 -0700 Subject: [PATCH 071/896] [SPARK-9585] Delete the input format caching because some input formats are not thread-safe If we cache the InputFormat, all tasks on the same executor will share it. Some InputFormats are thread-safe, but some, such as HiveHBaseTableInputFormat, are not. If tasks share a non-thread-safe InputFormat, unexpected errors may occur. To avoid this, I think we should delete the input format caching. Author: xutingjun Author: meiyoula <1039320815@qq.com> Author: Xutingjun Closes #7918 from XuTingjun/cached_inputFormat. --- core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 8f2655d63b79..77b57132b9f1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -182,17 +182,11 @@ class HadoopRDD[K, V]( } protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { - if (HadoopRDD.containsCachedMetadata(inputFormatCacheKey)) { - return HadoopRDD.getCachedMetadata(inputFormatCacheKey).asInstanceOf[InputFormat[K, V]] - } - // Once an InputFormat for this RDD is created, cache it so that only one reflection call is - // done in each local process.
val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) .asInstanceOf[InputFormat[K, V]] if (newInputFormat.isInstanceOf[Configurable]) { newInputFormat.asInstanceOf[Configurable].setConf(conf) } - HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat) newInputFormat } From 22d40159e60dd27a428e4051ef607292cbffbff3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 22 Sep 2015 11:07:01 -0700 Subject: [PATCH 072/896] [SPARK-10593] [SQL] fix resolve output of Generate The output of Generate should not be resolved as Reference. Author: Davies Liu Closes #8755 from davies/view. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 16 ++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 1 - .../catalyst/plans/logical/basicOperators.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 14 ++++++++++++++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02f34cbf58ad..bf72d47ce1ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -378,6 +378,22 @@ class Analyzer( val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) + // A special case for Generate, because the output of Generate should not be resolved by + // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. + case g @ Generate(generator, join, outer, qualifier, output, child) + if child.resolved && !generator.resolved => + val newG = generator transformUp { + case u @ UnresolvedAttribute(nameParts) => + withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) } + case UnresolvedExtractValue(child, fieldExpr) => + ExtractValue(child, fieldExpr, resolver) + } + if (newG.fastEquals(generator)) { + g + } else { + Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) + } + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 55286f9f2fc5..0ec9f0857108 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 722f69cdca82..ae9482c10f12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -68,7 +68,7 @@ case class Generate( generator.resolved && childrenResolved && generator.elementTypes.length == generatorOutput.length && - 
!generatorOutput.exists(!_.resolved) + generatorOutput.forall(_.resolved) } // we don't want the gOutput to be taken as part of the expressions diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8126d0233521..bb02473dd17c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1170,4 +1170,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sqlContext.table("`db.t`"), df) } } + + test("SPARK-10593 same column names in lateral view") { + val df = sqlContext.sql( + """ + |select + |insideLayer2.json as a2 + |from (select '{"layer1": {"layer2": "text inside layer 2"}}' json) test + |lateral view json_tuple(json, 'layer1') insideLayer1 as json + |lateral view json_tuple(insideLayer1.json, 'layer2') insideLayer2 as json + """.stripMargin + ) + + checkAnswer(df, Row("text inside layer 2") :: Nil) + } } From 1ca5e2e0b8d8d406c02a74c76ae9d7fc5637c8d3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Sep 2015 11:50:22 -0700 Subject: [PATCH 073/896] [SPARK-10704] Rename HashShuffleReader to BlockStoreShuffleReader The current shuffle code has an interface named ShuffleReader with only one implementation, HashShuffleReader. This naming is confusing, since the same read path code is used for both sort- and hash-based shuffle. This patch addresses this by renaming HashShuffleReader to BlockStoreShuffleReader. Author: Josh Rosen Closes #8825 from JoshRosen/shuffle-reader-cleanup. --- ...shShuffleReader.scala => BlockStoreShuffleReader.scala} | 5 ++--- .../org/apache/spark/shuffle/hash/HashShuffleManager.scala | 2 +- .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 3 +-- ...eaderSuite.scala => BlockStoreShuffleReaderSuite.scala} | 7 +++---- 4 files changed, 7 insertions(+), 10 deletions(-) rename core/src/main/scala/org/apache/spark/shuffle/{hash/HashShuffleReader.scala => BlockStoreShuffleReader.scala} (97%) rename core/src/test/scala/org/apache/spark/shuffle/{hash/HashShuffleReaderSuite.scala => BlockStoreShuffleReaderSuite.scala} (96%) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala rename to core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0c8f08f0f3b1..6dc9a16e5853 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -15,16 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -private[spark] class HashShuffleReader[K, C]( +private[spark] class BlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 0b46634b8b46..d2e2fc4c110a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -51,7 +51,7 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 476cc1f303da..9df4e551669c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { @@ -54,7 +53,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { // We currently use the same block store shuffle fetcher as the hash-based shuffle. - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala similarity index 96% rename from core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 05b3afef5b83..a5eafb1b5529 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer @@ -28,7 +28,6 @@ import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -56,7 +55,7 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed } } -class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { +class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { /** * This test makes sure that, when data is read from a HashShuffleReader, the underlying @@ -134,7 +133,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { new BaseShuffleHandle(shuffleId, numMaps, dependency) } - val shuffleReader = new HashShuffleReader( + val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, From 5017c685f484ec256101d1d33bad11d9e0c0f641 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Sep 2015 12:14:15 -0700 Subject: [PATCH 074/896] [SPARK-10740] [SQL] handle nondeterministic expressions correctly for set operations https://issues.apache.org/jira/browse/SPARK-10740 Author: Wenchen Fan Closes #8858 from cloud-fan/non-deter. --- .../sql/catalyst/optimizer/Optimizer.scala | 69 ++++++++++++++----- .../optimizer/SetOperationPushDownSuite.scala | 3 +- .../org/apache/spark/sql/DataFrameSuite.scala | 41 +++++++++++ 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 324f40a051c3..63602eaa8ccd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -95,14 +95,14 @@ object SamplePushDown extends Rule[LogicalPlan] { * Intersect: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * because we will not have non-deterministic expressions. + * with deterministic condition. * * Except: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * because we will not have non-deterministic expressions. + * with deterministic condition. */ -object SetOperationPushDown extends Rule[LogicalPlan] { +object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. @@ -129,34 +129,65 @@ object SetOperationPushDown extends Rule[LogicalPlan] { result.asInstanceOf[A] } + /** + * Splits the condition expression into small conditions by `And`, and partition them by + * deterministic, and finally recombine them by `And`. It returns an expression containing + * all deterministic expressions (the first field of the returned Tuple2) and an expression + * containing all non-deterministic expressions (the second field of the returned Tuple2). 
+ */ + private def partitionByDeterministic(condition: Expression): (Expression, Expression) = { + val andConditions = splitConjunctivePredicates(condition) + andConditions.partition(_.deterministic) match { + case (deterministic, nondeterministic) => + deterministic.reduceOption(And).getOrElse(Literal(true)) -> + nondeterministic.reduceOption(And).getOrElse(Literal(true)) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Push down filter into union case Filter(condition, u @ Union(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(u) - Union( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) - - // Push down projection through UNION ALL - case Project(projectList, u @ Union(left, right)) => - val rewrites = buildRewrites(u) - Union( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) + Filter(nondeterministic, + Union( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) + + // Push down deterministic projection through UNION ALL + case p @ Project(projectList, u @ Union(left, right)) => + if (projectList.forall(_.deterministic)) { + val rewrites = buildRewrites(u) + Union( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) + } else { + p + } // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(i) - Intersect( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) + Filter(nondeterministic, + Intersect( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) // Push down filter through EXCEPT case Filter(condition, e @ Except(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(e) - Except( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) + Filter(nondeterministic, + Except( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 3fca47a023dc..1595ad932742 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, - SetOperationPushDown) :: Nil + SetOperationPushDown, + SimplifyFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1370713975f2..d919877746c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -916,4 +916,45 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(intersect.count() === 30) assert(except.count() === 70) } + + test("SPARK-10740: handle nondeterministic expressions correctly 
+ for set operations") { + val df1 = (1 to 20).map(Tuple1.apply).toDF("i") + val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + + // When generating expected results at here, we need to follow the implementation of + // Rand expression. + def expected(df: DataFrame): Seq[Row] = { + df.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.filter(_.getInt(0) < rng.nextDouble() * 10) + } + } + + val union = df1.unionAll(df2) + checkAnswer( + union.filter('i < rand(7) * 10), + expected(union) + ) + checkAnswer( + union.select(rand(7)), + union.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.map(_ => rng.nextDouble()).map(i => Row(i)) + } + ) + + val intersect = df1.intersect(df2) + checkAnswer( + intersect.filter('i < rand(7) * 10), + expected(intersect) + ) + + val except = df1.except(df2) + checkAnswer( + except.filter('i < rand(7) * 10), + expected(except) + ) + } } From 2204cdb28483b249616068085d4e88554fe6acef Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 13:29:39 -0700 Subject: [PATCH 075/896] [SPARK-10672] [SQL] Do not fail when we cannot save the metadata of a data source table in a hive compatible way https://issues.apache.org/jira/browse/SPARK-10672 With the changes in this PR, we will fall back to saving the metadata of a table in a Spark SQL-specific way if we fail to save it in a Hive-compatible way (Hive throws an exception because of its internal restrictions, e.g. binary and decimal types cannot be saved to parquet if the metastore is running Hive 0.13). I manually tested the fix with the following test in `DataSourceWithHiveMetastoreCatalogSuite` (`spark.sql.hive.metastore.version=0.13` and `spark.sql.hive.metastore.jars`=`maven`). ``` test(s"fail to save metadata of a parquet table in hive 0.13") { withTempPath { dir => withTable("t") { val path = dir.getCanonicalPath sql( s"""CREATE TABLE t USING $provider |OPTIONS (path '$path') |AS SELECT 1 AS d1, cast("val_1" as binary) AS d2 """.stripMargin) sql( s"""describe formatted t """.stripMargin).collect.foreach(println) sqlContext.table("t").show } } } } ``` Without this fix, we will fail with the following error.
``` org.apache.hadoop.hive.ql.metadata.HiveException: java.lang.UnsupportedOperationException: Unknown field type: binary at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:619) at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:576) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply$mcV$sp(ClientWrapper.scala:359) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply(ClientWrapper.scala:357) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply(ClientWrapper.scala:357) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$withHiveState$1.apply(ClientWrapper.scala:256) at org.apache.spark.sql.hive.client.ClientWrapper.retryLocked(ClientWrapper.scala:211) at org.apache.spark.sql.hive.client.ClientWrapper.withHiveState(ClientWrapper.scala:248) at org.apache.spark.sql.hive.client.ClientWrapper.createTable(ClientWrapper.scala:357) at org.apache.spark.sql.hive.HiveMetastoreCatalog.createDataSourceTable(HiveMetastoreCatalog.scala:358) at org.apache.spark.sql.hive.execution.CreateMetastoreDataSourceAsSelect.run(commands.scala:285) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult$lzycompute(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.doExecute(commands.scala:69) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$5.apply(SparkPlan.scala:140) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$5.apply(SparkPlan.scala:138) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:150) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:138) at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:58) at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:58) at org.apache.spark.sql.DataFrame.(DataFrame.scala:144) at org.apache.spark.sql.DataFrame.(DataFrame.scala:129) at org.apache.spark.sql.DataFrame$.apply(DataFrame.scala:51) at org.apache.spark.sql.SQLContext.sql(SQLContext.scala:725) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$sql$1.apply(SQLTestUtils.scala:56) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$sql$1.apply(SQLTestUtils.scala:56) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2$$anonfun$apply$2.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:165) at org.apache.spark.sql.test.SQLTestUtils$class.withTable(SQLTestUtils.scala:150) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTable(HiveMetastoreCatalogSuite.scala:52) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2.apply(HiveMetastoreCatalogSuite.scala:162) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2.apply(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.test.SQLTestUtils$class.withTempPath(SQLTestUtils.scala:125) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTempPath(HiveMetastoreCatalogSuite.scala:52) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:161) at 
org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:161) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:42) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) at org.scalatest.FunSuite.runTest(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:413) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:401) at scala.collection.immutable.List.foreach(List.scala:318) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:396) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:483) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:208) at org.scalatest.FunSuite.runTests(FunSuite.scala:1555) at org.scalatest.Suite$class.run(Suite.scala:1424) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.SuperEngine.runImpl(Engine.scala:545) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:212) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.org$scalatest$BeforeAndAfterAll$$super$run(HiveMetastoreCatalogSuite.scala:52) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:257) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:256) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.run(HiveMetastoreCatalogSuite.scala:52) at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:462) at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:671) at sbt.ForkMain$Run$2.call(ForkMain.java:294) at sbt.ForkMain$Run$2.call(ForkMain.java:284) at java.util.concurrent.FutureTask.run(FutureTask.java:262) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) Caused by: java.lang.UnsupportedOperationException: Unknown field type: binary at org.apache.hadoop.hive.ql.io.parquet.serde.ArrayWritableObjectInspector.getObjectInspector(ArrayWritableObjectInspector.java:108) at org.apache.hadoop.hive.ql.io.parquet.serde.ArrayWritableObjectInspector.(ArrayWritableObjectInspector.java:60) at org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe.initialize(ParquetHiveSerDe.java:113) at 
org.apache.hadoop.hive.metastore.MetaStoreUtils.getDeserializer(MetaStoreUtils.java:339) at org.apache.hadoop.hive.ql.metadata.Table.getDeserializerFromMetaStore(Table.java:288) at org.apache.hadoop.hive.ql.metadata.Table.checkValidity(Table.java:194) at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:597) ... 76 more ``` Author: Yin Huai Closes #8824 from yhuai/datasourceMetadata. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 101 +++++++++--------- 1 file changed, 50 insertions(+), 51 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0c1b41e3377e..012634cb5aeb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -309,69 +309,68 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } // TODO: Support persisting partitioned data source relations in Hive compatible format - val hiveTable = (maybeSerDe, dataSource.relation) match { + val qualifiedTableName = tableIdent.quotedString + val (hiveCompitiableTable, logMessage) = (maybeSerDe, dataSource.relation) match { case (Some(serde), relation: HadoopFsRelation) - if relation.paths.length == 1 && relation.partitionColumns.isEmpty => - // Hive ParquetSerDe doesn't support decimal type until 1.2.0. - val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) - val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) - - val hiveParquetSupportsDecimal = client.version match { - case org.apache.spark.sql.hive.client.hive.v1_2 => true - case _ => false - } - - if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { - // If Hive version is below 1.2.0, we cannot save Hive compatible schema to - // metastore when the file format is Parquet and the schema has DecimalType. - logWarning { - "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + - "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " + - s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." - } - newSparkSQLSpecificMetastoreTable() - } else { - logInfo { - "Persisting data source relation with a single input path into Hive metastore in " + - s"Hive compatible format. Input path: ${relation.paths.head}" - } - newHiveCompatibleMetastoreTable(relation, serde) - } + if relation.paths.length == 1 && relation.partitionColumns.isEmpty => + val hiveTable = newHiveCompatibleMetastoreTable(relation, serde) + val message = + s"Persisting data source relation $qualifiedTableName with a single input path " + + s"into Hive metastore in Hive compatible format. Input path: ${relation.paths.head}." + (Some(hiveTable), message) case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => - logWarning { - "Persisting partitioned data source relation into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive. Input path(s): " + - relation.paths.mkString("\n", "\n", "") - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Persisting partitioned data source relation $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. 
" + + "Input path(s): " + relation.paths.mkString("\n", "\n", "") + (None, message) case (Some(serde), relation: HadoopFsRelation) => - logWarning { - "Persisting data source relation with multiple input paths into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive. Input paths: " + - relation.paths.mkString("\n", "\n", "") - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Persisting data source relation $qualifiedTableName with multiple input paths into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + + s"Input paths: " + relation.paths.mkString("\n", "\n", "") + (None, message) case (Some(serde), _) => - logWarning { - s"Data source relation is not a ${classOf[HadoopFsRelation].getSimpleName}. " + - "Persisting it into Hive metastore in Spark SQL specific format, " + - "which is NOT compatible with Hive." - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Data source relation $qualifiedTableName is not a " + + s"${classOf[HadoopFsRelation].getSimpleName}. Persisting it into Hive metastore " + + "in Spark SQL specific format, which is NOT compatible with Hive." + (None, message) case _ => - logWarning { + val message = s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + - "Persisting data source relation into Hive metastore in Spark SQL specific format, " + - "which is NOT compatible with Hive." - } - newSparkSQLSpecificMetastoreTable() + s"Persisting data source relation $qualifiedTableName into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive." + (None, message) } - client.createTable(hiveTable) + (hiveCompitiableTable, logMessage) match { + case (Some(table), message) => + // We first try to save the metadata of the table in a Hive compatiable way. + // If Hive throws an error, we fall back to save its metadata in the Spark SQL + // specific way. + try { + logInfo(message) + client.createTable(table) + } catch { + case throwable: Throwable => + val warningMessage = + s"Could not persist $qualifiedTableName in a Hive compatible way. Persisting " + + s"it into Hive metastore in Spark SQL specific format." + logWarning(warningMessage, throwable) + val sparkSqlSpecificTable = newSparkSQLSpecificMetastoreTable() + client.createTable(sparkSqlSpecificTable) + } + + case (None, message) => + logWarning(message) + val hiveTable = newSparkSQLSpecificMetastoreTable() + client.createTable(hiveTable) + } } def hiveDefaultTableFilePath(tableName: String): String = { From 5aea987c904b281d7952ad8db40a32561b4ec5cf Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 13:31:35 -0700 Subject: [PATCH 076/896] [SPARK-10737] [SQL] When using UnsafeRows, SortMergeJoin may return wrong results https://issues.apache.org/jira/browse/SPARK-10737 Author: Yin Huai Closes #8854 from yhuai/SMJBug. 
--- .../codegen/GenerateProjection.scala | 2 ++ .../apache/spark/sql/execution/Window.scala | 9 ++++-- .../sql/execution/joins/SortMergeJoin.scala | 25 +++++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 28 +++++++++++++++++++ 4 files changed, 59 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 2164ddf03d1b..75524b568d68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -171,6 +171,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { @Override public Object apply(Object r) { + // GenerateProjection does not work with UnsafeRows. + assert(!(r instanceof ${classOf[UnsafeRow].getName})); return new SpecificRow((InternalRow) r); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 0269d6d4b7a1..f8929530c503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -253,7 +253,11 @@ case class Window( // Get all relevant projections. val result = createResultProjection(unboundExpressions) - val grouping = newProjection(partitionSpec, child.output) + val grouping = if (child.outputsUnsafeRows) { + UnsafeProjection.create(partitionSpec, child.output) + } else { + newProjection(partitionSpec, child.output) + } // Manage the stream and the grouping. var nextRow: InternalRow = EmptyRow @@ -277,7 +281,8 @@ case class Window( val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. - val currentGroup = nextGroup + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() rows = new CompactBuffer while (nextRowAvailable && nextGroup == currentGroup) { rows += nextRow.copy() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 906f20d2a728..70a1af6a7063 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -56,9 +56,6 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) - @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) - protected[this] def isUnsafeMode: Boolean = { (codegenEnabled && unsafeEnabled && UnsafeProjection.canSupport(leftKeys) @@ -82,6 +79,28 @@ case class SortMergeJoin( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new RowIterator { + // The projection used to extract keys from input rows of the left child. + private[this] val leftKeyGenerator = { + if (isUnsafeMode) { + // It is very important to use UnsafeProjection if input rows are UnsafeRows. + // Otherwise, GenerateProjection will cause wrong results. 
+ UnsafeProjection.create(leftKeys, left.output) + } else { + newProjection(leftKeys, left.output) + } + } + + // The projection used to extract keys from input rows of the right child. + private[this] val rightKeyGenerator = { + if (isUnsafeMode) { + // It is very important to use UnsafeProjection if input rows are UnsafeRows. + // Otherwise, GenerateProjection will cause wrong results. + UnsafeProjection.create(rightKeys, right.output) + } else { + newProjection(rightKeys, right.output) + } + } + // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) private[this] var currentLeftRow: InternalRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 05b4127cbcaf..eca6f1073889 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1781,4 +1781,32 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(1), Row(1))) } } + + test("SortMergeJoin returns wrong results when using UnsafeRows") { + // This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737. + // This bug will be triggered when Tungsten is enabled and there are multiple + // SortMergeJoin operators executed in the same task. + val confs = + SQLConf.SORTMERGE_JOIN.key -> "true" :: + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: + SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil + withSQLConf(confs: _*) { + val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j") + val df2 = + df1 + .join(df1.select(df1("i")), "i") + .select(df1("i"), df1("j")) + + val df3 = df2.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1") + val df4 = + df2 + .join(df3, df2("i") === df3("i1")) + .withColumn("diff", $"j" - $"j1") + .select(df2("i"), df2("j"), $"diff") + + checkAnswer( + df4, + df1.withColumn("diff", lit(0))) + } + } } From a96ba40f7ee1352288ea676d8844e1c8174202eb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Sep 2015 14:11:46 -0700 Subject: [PATCH 077/896] [SPARK-10714] [SPARK-8632] [SPARK-10685] [SQL] Refactor Python UDF handling This patch refactors Python UDF handling: 1. Extract the per-partition Python UDF calling logic from PythonRDD into a PythonRunner. PythonRunner itself expects iterator as input/output, and thus has no dependency on RDD. This way, we can use PythonRunner directly in a mapPartitions call, or in the future in an environment without RDDs. 2. Use PythonRunner in Spark SQL's BatchPythonEvaluation. 3. Updated BatchPythonEvaluation to only use its input once, rather than twice. This should fix Python UDF performance regression in Spark 1.5. There are a number of small cleanups I wanted to do when I looked at the code, but I kept most of those out so the diff looks small. This basically implements the approach in https://github.com/apache/spark/pull/8833, but with some code moving around so the correctness doesn't depend on the inner workings of Spark serialization and task execution. Author: Reynold Xin Closes #8835 from rxin/python-iter-refactor. 
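Point 1 above is the structural core of the change: PythonRunner is iterator-in, iterator-out, so any per-partition iterator can drive it rather than only PythonRDD.compute. A rough sketch of that usage; `rdd`, `command`, `envVars`, `pythonIncludes`, `pythonExec`, `pythonVer`, `broadcastVars`, `accumulator`, `bufferSize` and `reuseWorker` are assumed to be in scope with the same meanings as the PythonRDD fields in the diff below, and PythonRunner itself is private[spark], so this shape only compiles inside Spark:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.api.python.PythonRunner

// Drive a PythonRunner directly from mapPartitions: the runner streams the partition's
// elements to a Python worker and returns an iterator over the pickled results.
val result = rdd.mapPartitions { iter =>
  val runner = new PythonRunner(
    command, envVars, pythonIncludes, pythonExec, pythonVer,
    broadcastVars, accumulator, bufferSize, reuseWorker)
  val context = TaskContext.get()
  runner.compute(iter, context.partitionId(), context)
}
```

This is essentially how the updated BatchPythonEvaluation in this patch consumes the runner as well: it feeds a projected-input iterator into compute() and joins the returned output iterator back to the input rows it buffered in a queue.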
--- .../apache/spark/api/python/PythonRDD.scala | 54 ++++++++++--- .../spark/sql/execution/pythonUDFs.scala | 80 +++++++++++-------- 2 files changed, 89 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 69da180593bb..3788d1829758 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -24,6 +24,7 @@ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JM import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials +import scala.util.control.NonFatal import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.conf.Configuration @@ -38,7 +39,6 @@ import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.util.control.NonFatal private[spark] class PythonRDD( parent: RDD[_], @@ -61,11 +61,39 @@ private[spark] class PythonRDD( if (preservePartitoning) firstParent.partitioner else None } + val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val runner = new PythonRunner( + command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator, + bufferSize, reuse_worker) + runner.compute(firstParent.iterator(split, context), split.index, context) + } +} + + +/** + * A helper class to run Python UDFs in Spark. + */ +private[spark] class PythonRunner( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVer: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + bufferSize: Int, + reuse_worker: Boolean) + extends Logging { + + def compute( + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map( - f => f.getPath()).mkString(",") + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuse_worker) { envVars.put("SPARK_REUSE_WORKER", "1") @@ -75,7 +103,7 @@ private[spark] class PythonRDD( @volatile var released = false // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, split, context) + val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() @@ -183,13 +211,16 @@ private[spark] class PythonRDD( new InterruptibleIterator(context, stdoutIterator) } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) - /** * The thread responsible for writing the data from the PythonRDD's parent iterator to the * Python process. 
*/ - class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext) + class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext) extends Thread(s"stdout writer for $pythonExec") { @volatile private var _exception: Exception = null @@ -211,11 +242,11 @@ private[spark] class PythonRDD( val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index - dataOut.writeInt(split.index) + dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.size()) for (include <- pythonIncludes.asScala) { @@ -246,7 +277,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() @@ -327,7 +358,8 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() - private def getWorkerBroadcasts(worker: Socket) = { + + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index d0411da6fdf5..c35c726bfc50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil} +import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Accumulator, Logging => SparkLogging} +import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator} /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -329,7 +329,13 @@ case class EvaluatePython( /** * :: DeveloperApi :: * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * The input data is zipped with the result of the udf evaluation. + * + * Python evaluation works by sending the necessary (projected) input data via a socket to an + * external Python process, and combine the result from the Python process with the original row. + * + * For each row we send to Python, we also put it in a queue. For each output row from Python, + * we drain the queue to find the original input row. 
Note that if the Python process is way too + * slow, this could lead to the queue growing unbounded and eventually run out of memory. */ @DeveloperApi case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) @@ -342,51 +348,57 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = { - val childResults = child.execute().map(_.copy()) + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - val parent = childResults.mapPartitions { iter => + inputRDD.mapPartitions { iter => EvaluatePython.registerPicklers() // register pickler for Row + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + val pickle = new Pickler val currentRow = newMutableProjection(udf.children, child.output)() val fields = udf.children.map(_.dataType) val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) - iter.grouped(100).map { inputRows => + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { row => + queue.add(row) EvaluatePython.toJava(currentRow(row), schema) }.toArray pickle.dumps(toBePickled) } - } - val pyRDD = new PythonRDD( - parent, - udf.command, - udf.envVars, - udf.pythonIncludes, - false, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator - ).mapPartitions { iter => - val pickle = new Unpickler - iter.flatMap { pickedResult => - val unpickledBatch = pickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - } - }.mapPartitions { iter => + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = new PythonRunner( + udf.command, + udf.envVars, + udf.pythonIncludes, + udf.pythonExec, + udf.pythonVer, + udf.broadcastVars, + udf.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler val row = new GenericMutableRow(1) - iter.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - row: InternalRow - } - } + val joined = new JoinedRow - childResults.zip(pyRDD).mapPartitions { iter => - val joinedRow = new JoinedRow() - iter.map { - case (row, udfResult) => - joinedRow(row, udfResult) + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + row(0) = EvaluatePython.fromJava(result, udf.dataType) + joined(queue.poll(), row) } } } From 61d4c07f4becb42f054e588be56ed13239644410 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 22 Sep 2015 16:35:43 -0700 Subject: [PATCH 078/896] [SPARK-10640] History server fails to parse TaskCommitDenied ... simply because the code is missing! Author: Andrew Or Closes #8828 from andrewor14/task-end-reason-json. 
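The interesting part of the fix is the backward-compatibility handling: the JSON fields do not exist in logs written between 1.3.0 and 1.5.0, so every field is read as optional with -1 as the fallback. The same pattern can be sketched outside Spark with json4s, the library JsonProtocol builds on; the helper below is a standalone illustration, not Spark code, and only the field names mirror the patch.

    import org.json4s._
    import org.json4s.jackson.JsonMethods.parse

    object TaskCommitDeniedJsonSketch {
      case class CommitDenied(jobId: Int, partitionId: Int, attemptNumber: Int)

      private implicit val formats: Formats = DefaultFormats

      // Read an Int field if present, otherwise fall back to a sentinel so that
      // event logs written before the field existed still replay cleanly.
      private def intOr(json: JValue, field: String, default: Int = -1): Int =
        (json \ field) match {
          case JNothing | JNull => default
          case v => v.extract[Int]
        }

      def fromJson(json: JValue): CommitDenied =
        CommitDenied(
          intOr(json, "Job ID"),
          intOr(json, "Partition ID"),
          intOr(json, "Attempt Number"))

      def main(args: Array[String]): Unit = {
        // An old log entry that carries none of the new fields.
        assert(fromJson(parse("""{"Reason": "TaskCommitDenied"}""")) == CommitDenied(-1, -1, -1))
        // A current entry round-trips its real values.
        assert(fromJson(parse(
          """{"Job ID": 2, "Partition ID": 3, "Attempt Number": 4}""")) == CommitDenied(2, 3, 4))
      }
    }
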
--- .../scala/org/apache/spark/TaskEndReason.scala | 6 +++++- .../org/apache/spark/util/JsonProtocol.scala | 13 +++++++++++++ .../apache/spark/util/JsonProtocolSuite.scala | 17 +++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 7137246bc34f..9335c5f4160b 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,13 +17,17 @@ package org.apache.spark -import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils +// ============================================================================================== +// NOTE: new task end reasons MUST be accompanied with serialization logic in util.JsonProtocol! +// ============================================================================================== + /** * :: DeveloperApi :: * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 99614a786bd9..40729fa5a4ff 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -362,6 +362,10 @@ private[spark] object JsonProtocol { ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) + case taskCommitDenied: TaskCommitDenied => + ("Job ID" -> taskCommitDenied.jobID) ~ + ("Partition ID" -> taskCommitDenied.partitionID) ~ + ("Attempt Number" -> taskCommitDenied.attemptNumber) case ExecutorLostFailure(executorId, isNormalExit) => ("Executor ID" -> executorId) ~ ("Normal Exit" -> isNormalExit) @@ -770,6 +774,7 @@ private[spark] object JsonProtocol { val exceptionFailure = Utils.getFormattedClassName(ExceptionFailure) val taskResultLost = Utils.getFormattedClassName(TaskResultLost) val taskKilled = Utils.getFormattedClassName(TaskKilled) + val taskCommitDenied = Utils.getFormattedClassName(TaskCommitDenied) val executorLostFailure = Utils.getFormattedClassName(ExecutorLostFailure) val unknownReason = Utils.getFormattedClassName(UnknownReason) @@ -794,6 +799,14 @@ private[spark] object JsonProtocol { ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled + case `taskCommitDenied` => + // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON + // de/serialization logic was not added until 1.5.1. To provide backward compatibility + // for reading those logs, we need to provide default values for all the fields. + val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) + val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) + val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) + TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => val isNormalExit = Utils.jsonOption(json \ "Normal Exit"). 
map(_.extract[Boolean]) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 143c1b901df1..a24bf2931cca 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -151,6 +151,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) + testTaskEndReason(TaskCommitDenied(2, 3, 4)) testTaskEndReason(ExecutorLostFailure("100", true)) testTaskEndReason(UnknownReason) @@ -352,6 +353,17 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedStageInfo, JsonProtocol.stageInfoFromJson(oldStageInfo)) } + // `TaskCommitDenied` was added in 1.3.0 but JSON de/serialization logic was added in 1.5.1 + test("TaskCommitDenied backward compatibility") { + val denied = TaskCommitDenied(1, 2, 3) + val oldDenied = JsonProtocol.taskEndReasonToJson(denied) + .removeField({ _._1 == "Job ID" }) + .removeField({ _._1 == "Partition ID" }) + .removeField({ _._1 == "Attempt Number" }) + val expectedDenied = TaskCommitDenied(-1, -1, -1) + assertEquals(expectedDenied, JsonProtocol.taskEndReasonFromJson(oldDenied)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -577,6 +589,11 @@ class JsonProtocolSuite extends SparkFunSuite { assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => + case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), + TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) => + assert(jobId1 === jobId2) + assert(partitionId1 === partitionId2) + assert(attemptNumber1 === attemptNumber2) case (ExecutorLostFailure(execId1, isNormalExit1), ExecutorLostFailure(execId2, isNormalExit2)) => assert(execId1 === execId2) From 84f81e035e1dab1b42c36563041df6ba16e7b287 Mon Sep 17 00:00:00 2001 From: Zhichao Li Date: Tue, 22 Sep 2015 19:41:57 -0700 Subject: [PATCH 079/896] [SPARK-10310] [SQL] Fixes script transformation field/line delimiters **Please attribute this PR to `Zhichao Li `.** This PR is based on PR #8476 authored by zhichao-li. It fixes SPARK-10310 by adding field delimiter SerDe property to the default `LazySimpleSerDe`, and enabling default record reader/writer classes. Currently, we only support `LazySimpleSerDe`, used together with `TextRecordReader` and `TextRecordWriter`, and don't support customizing record reader/writer using `RECORDREADER`/`RECORDWRITER` clauses. This should be addressed in separate PR(s). Author: Cheng Lian Closes #8860 from liancheng/spark-10310/fix-script-trans-delimiters. 
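Seen from the query side, the fix means two things: a TRANSFORM clause with no ROW FORMAT now presents tab-delimited fields to the script through Hive's default TextRecordReader/TextRecordWriter, and an explicit LazySimpleSerDe with its own field.delim is respected on both sides of the script. A hedged usage sketch, assuming a HiveContext bound to sqlContext, a registered table test with columns a and b, and a user-supplied script my_transform.py (all placeholders, modeled on the tests added later in this patch):

    // No ROW FORMAT clause: the default LazySimpleSerDe now carries an explicit
    // tab field delimiter, so a plain pass-through script sees tab-separated fields.
    val defaultFormat = sqlContext.sql(
      """FROM test
        |SELECT TRANSFORM(a, b)
        |USING 'cat'
        |AS (c STRING, d STRING)
      """.stripMargin)

    // Explicit LazySimpleSerDe with a custom delimiter, applied both to the rows
    // written to the script and to the rows read back from it.
    val customDelim = sqlContext.sql(
      """FROM test
        |SELECT TRANSFORM(a, b)
        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
        |WITH SERDEPROPERTIES('field.delim' = '|')
        |USING 'python my_transform.py "|"'
        |AS (c STRING, d STRING)
        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
        |WITH SERDEPROPERTIES('field.delim' = '|')
      """.stripMargin)
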
--- .../org/apache/spark/sql/hive/HiveQl.scala | 52 ++++++++++--- .../hive/execution/ScriptTransformation.scala | 75 +++++++++++++++---- .../resources/data/scripts/test_transform.py | 6 ++ .../sql/hive/execution/SQLQuerySuite.scala | 39 ++++++++++ .../execution/ScriptTransformationSuite.scala | 2 + 5 files changed, 152 insertions(+), 22 deletions(-) create mode 100755 sql/hive/src/test/resources/data/scripts/test_transform.py diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index d5cd7e98b526..256440a9a2e9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException @@ -884,16 +885,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C AttributeReference("value", StringType)()), true) } - def matchSerDe(clause: Seq[ASTNode]) - : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match { + type SerDeInfo = ( + Seq[(String, String)], // Input row format information + Option[String], // Optional input SerDe class + Seq[(String, String)], // Input SerDe properties + Boolean // Whether to use default record reader/writer + ) + + def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) } - (rowFormat, None, Nil) + (rowFormat, None, Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", @@ -903,20 +910,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C (BaseSemanticAnalyzer.unescapeSQLString(name), BaseSemanticAnalyzer.unescapeSQLString(value)) } - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil) + // SPARK-10310: Special cases LazySimpleSerDe + // TODO Fully supports user-defined record reader/writer classes + val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) + val useDefaultRecordReaderWriter = + unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName + (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) + + case Nil => + // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here + val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") + (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) } - val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) - val (outRowFormat, outSerdeClass, outSerdeProps) = matchSerDe(outputSerdeClause) + val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = + matchSerDe(inputSerdeClause) + + val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = + matchSerDe(outputSerdeClause) val unescapedScript = 
BaseSemanticAnalyzer.unescapeSQLString(script) + // TODO Adds support for user-defined record reader/writer classes + val recordReaderClass = if (useDefaultRecordReader) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) + } else { + None + } + + val recordWriterClass = if (useDefaultRecordWriter) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) + } else { + None + } + val schema = HiveScriptIOSchema( inRowFormat, outRowFormat, inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, schemaLess) + inSerdeProps, outSerdeProps, + recordReaderClass, recordWriterClass, + schemaLess) Some( logical.ScriptTransformation( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 32bddbaeaeaf..b30117f0de99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -24,20 +24,22 @@ import javax.annotation.Nullable import scala.collection.JavaConverters._ import scala.util.control.NonFatal +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} import org.apache.spark.{Logging, TaskContext} /** @@ -58,6 +60,8 @@ case class ScriptTransformation( override def otherCopyArgs: Seq[HiveContext] = sc :: Nil + private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) + protected override def doExecute(): RDD[InternalRow] = { def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = List("/bin/bash", "-c", script) @@ -67,6 +71,7 @@ case class ScriptTransformation( val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val errorStream = proc.getErrorStream + val localHiveConf = serializedHiveConf.value // In order to avoid deadlocks, we need to consume the error output of the child process. 
// To avoid issues caused by large error output, we use a circular buffer to limit the amount @@ -96,7 +101,8 @@ case class ScriptTransformation( outputStream, proc, stderrBuffer, - TaskContext.get() + TaskContext.get(), + localHiveConf ) // This nullability is a performance optimization in order to avoid an Option.foreach() call @@ -109,6 +115,10 @@ case class ScriptTransformation( val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) + + @Nullable val scriptOutputReader = + ioschema.recordReader(scriptOutputStream, localHiveConf).orNull + var scriptOutputWritable: Writable = null val reusedWritableObject: Writable = if (null != outputSerde) { outputSerde.getSerializedClass().newInstance @@ -134,15 +144,25 @@ case class ScriptTransformation( } } else if (scriptOutputWritable == null) { scriptOutputWritable = reusedWritableObject - try { - scriptOutputWritable.readFields(scriptOutputStream) - true - } catch { - case _: EOFException => - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } + + if (scriptOutputReader != null) { + if (scriptOutputReader.next(scriptOutputWritable) <= 0) { + writerThread.exception.foreach(throw _) false + } else { + true + } + } else { + try { + scriptOutputWritable.readFields(scriptOutputStream) + true + } catch { + case _: EOFException => + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } } } else { true @@ -210,7 +230,8 @@ private class ScriptTransformationWriterThread( outputStream: OutputStream, proc: Process, stderrBuffer: CircularBuffer, - taskContext: TaskContext + taskContext: TaskContext, + conf: Configuration ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { setDaemon(true) @@ -224,6 +245,7 @@ private class ScriptTransformationWriterThread( TaskContext.setTaskContext(taskContext) val dataOutputStream = new DataOutputStream(outputStream) + @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception @@ -250,7 +272,12 @@ private class ScriptTransformationWriterThread( } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + + if (scriptInputWriter != null) { + scriptInputWriter.write(writable) + } else { + prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + } } } outputStream.close() @@ -290,6 +317,8 @@ case class HiveScriptIOSchema ( outputSerdeClass: Option[String], inputSerdeProps: Seq[(String, String)], outputSerdeProps: Seq[(String, String)], + recordReaderClass: Option[String], + recordWriterClass: Option[String], schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { private val defaultFormat = Map( @@ -347,4 +376,24 @@ case class HiveScriptIOSchema ( serde } + + def recordReader( + inputStream: InputStream, + conf: Configuration): Option[RecordReader] = { + recordReaderClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] + val props = new Properties() + props.putAll(outputSerdeProps.toMap.asJava) + instance.initialize(inputStream, conf, props) + instance + } + } + + def recordWriter(outputStream: OutputStream, conf: 
Configuration): Option[RecordWriter] = { + recordWriterClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter] + instance.initialize(outputStream, conf) + instance + } + } } diff --git a/sql/hive/src/test/resources/data/scripts/test_transform.py b/sql/hive/src/test/resources/data/scripts/test_transform.py new file mode 100755 index 000000000000..ac6d11d8b919 --- /dev/null +++ b/sql/hive/src/test/resources/data/scripts/test_transform.py @@ -0,0 +1,6 @@ +import sys + +delim = sys.argv[1] + +for row in sys.stdin: + print(delim.join([w + '#' for w in row[:-1].split(delim)])) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index bb02473dd17c..71823e32ad38 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1184,4 +1184,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(df, Row("text inside layer 2") :: Nil) } + + test("SPARK-10310: " + + "script transformation using default input/output SerDe and record reader/writer") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + checkAnswer( + sql( + """FROM( + | FROM test SELECT TRANSFORM(a, b) + | USING 'python src/test/resources/data/scripts/test_transform.py "\t"' + | AS (c STRING, d STRING) + |) t + |SELECT c + """.stripMargin), + (0 until 5).map(i => Row(i + "#"))) + } + + test("SPARK-10310: script transformation using LazySimpleSerDe") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + val df = sql( + """FROM test + |SELECT TRANSFORM(a, b) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + |USING 'python src/test/resources/data/scripts/test_transform.py "|"' + |AS (c STRING, d STRING) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + """.stripMargin) + + checkAnswer(df, (0 until 5).map(i => Row(i + "#", i + "#"))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index cb8d0fca8e69..7cfdb886b585 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -38,6 +38,8 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { outputSerdeClass = None, inputSerdeProps = Seq.empty, outputSerdeProps = Seq.empty, + recordReaderClass = None, + recordWriterClass = None, schemaLess = false ) From 558e9c7e60a7c0d85ba26634e97562ad2163e91d Mon Sep 17 00:00:00 2001 From: Matt Hagen Date: Tue, 22 Sep 2015 21:14:25 -0700 Subject: [PATCH 080/896] [SPARK-10663] Removed unnecessary invocation of DataFrame.toDF method. The Scala example under the "Example: Pipeline" heading in this document initializes the "test" variable to a DataFrame. Because test is already a DF, there is not need to call test.toDF as the example does in a subsequent line: model.transform(test.toDF). So, I removed the extraneous toDF invocation. Author: Matt Hagen Closes #8875 from hagenhaus/SPARK-10663. 
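For context, the corrected guide snippet ends up looking roughly as follows. This is a sketch reconstructed around the diff: model is the PipelineModel fitted earlier in the guide's example, test is the DataFrame built just above the changed line, and the imports are assumed from the surrounding example rather than quoted from it.

    import org.apache.spark.sql.Row
    import org.apache.spark.mllib.linalg.Vector

    // `test` is already a DataFrame, so it is passed to transform() directly;
    // the extra toDF call removed by this patch was a no-op wrapper.
    model.transform(test)
      .select("id", "text", "probability", "prediction")
      .collect()
      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
        println(s"($id, $text) --> prob=$prob, prediction=$prediction")
      }
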
--- docs/ml-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 0427ac6695aa..fd3a6167bc65 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -475,7 +475,7 @@ val test = sqlContext.createDataFrame(Seq( )).toDF("id", "text") // Make predictions on test documents. -model.transform(test.toDF) +model.transform(test) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => From 5548a254755bb84edae2768b94ab1816e1b49b91 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Sep 2015 22:44:09 -0700 Subject: [PATCH 081/896] [SPARK-10652] [SPARK-10742] [STREAMING] Set meaningful job descriptions for all streaming jobs Here is the screenshot after adding the job descriptions to threads that run receivers and the scheduler thread running the batch jobs. ## All jobs page * Added job descriptions with links to relevant batch details page ![image](https://cloud.githubusercontent.com/assets/663212/9924165/cda4a372-5cb1-11e5-91ca-d43a32c699e9.png) ## All stages page * Added stage descriptions with links to relevant batch details page ![image](https://cloud.githubusercontent.com/assets/663212/9923814/2cce266a-5cae-11e5-8a3f-dad84d06c50e.png) ## Streaming batch details page * Added the +details link ![image](https://cloud.githubusercontent.com/assets/663212/9921977/24014a32-5c98-11e5-958e-457b6c38065b.png) Author: Tathagata Das Closes #8791 from tdas/SPARK-10652. --- .../scala/org/apache/spark/ui/UIUtils.scala | 62 ++++++++++++++++- .../apache/spark/ui/jobs/AllJobsPage.scala | 14 ++-- .../org/apache/spark/ui/jobs/StageTable.scala | 7 +- .../org/apache/spark/ui/UIUtilsSuite.scala | 66 +++++++++++++++++++ .../spark/streaming/StreamingContext.scala | 4 +- .../streaming/scheduler/JobScheduler.scala | 15 ++++- .../streaming/scheduler/ReceiverTracker.scala | 5 +- .../apache/spark/streaming/ui/BatchPage.scala | 33 ++++++---- .../streaming/StreamingContextSuite.scala | 2 +- 9 files changed, 179 insertions(+), 29 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index f2da41772410..21dc8f0b6548 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -18,9 +18,11 @@ package org.apache.spark.ui import java.text.SimpleDateFormat -import java.util.{Locale, Date} +import java.util.{Date, Locale} -import scala.xml.{Node, Text, Unparsed} +import scala.util.control.NonFatal +import scala.xml._ +import scala.xml.transform.{RewriteRule, RuleTransformer} import org.apache.spark.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -395,4 +397,60 @@ private[spark] object UIUtils extends Logging { } + /** + * Returns HTML rendering of a job or stage description. It will try to parse the string as HTML + * and make sure that it only contains anchors with root-relative links. Otherwise, + * the whole string will rendered as a simple escaped text. + * + * Note: In terms of security, only anchor tags with root relative links are supported. 
So any + * attempts to embed links outside Spark UI, or other tags like } private def createExecutorTable() : Seq[Node] = { From 9631ca35275b0ce8a5219f975907ac36ed11f528 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 18 Nov 2015 08:59:20 +0000 Subject: [PATCH 758/896] [SPARK-11652][CORE] Remote code execution with InvokerTransformer Update to Commons Collections 3.2.2 to avoid any potential remote code execution vulnerability Author: Sean Owen Closes #9731 from srowen/SPARK-11652. --- pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pom.xml b/pom.xml index 940e2d8740bf..ad849112ce76 100644 --- a/pom.xml +++ b/pom.xml @@ -162,6 +162,8 @@ 3.1 3.4.1 + + 3.2.2 2.10.5 2.10 ${scala.version} @@ -475,6 +477,11 @@ commons-math3 ${commons.math3.version} + + org.apache.commons + commons-collections + ${commons.collections.version} + org.apache.ivy ivy From 1429e0a2b562469146b6fa06051c85a00092e5b8 Mon Sep 17 00:00:00 2001 From: Viveka Kulharia Date: Wed, 18 Nov 2015 09:10:15 +0000 Subject: [PATCH 759/896] rmse was wrongly calculated It was multiplying with U instaed of dividing by U Author: Viveka Kulharia Closes #9771 from vivkul/patch-1. --- examples/src/main/python/als.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 1c3a787bd0e9..205ca02962be 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -36,7 +36,7 @@ def rmse(R, ms, us): diff = R - ms * us.T - return np.sqrt(np.sum(np.power(diff, 2)) / M * U) + return np.sqrt(np.sum(np.power(diff, 2)) / (M * U)) def update(i, vec, mat, ratings): From 3a6807fdf07b0e73d76502a6bd91ad979fde8b61 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 18 Nov 2015 08:18:54 -0800 Subject: [PATCH 760/896] =?UTF-8?q?[SPARK-11804]=20[PYSPARK]=20Exception?= =?UTF-8?q?=20raise=20when=20using=20Jdbc=20predicates=20opt=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ion in PySpark Author: Jeff Zhang Closes #9791 from zjffdu/SPARK-11804. 
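On the JVM side, the predicates variant of DataFrameReader.jdbc already takes an Array[String]; the PySpark fix above simply builds that array through the Py4J gateway (the new toJArray helper) instead of the call that previously raised. For reference, a sketch of the call the Python wrapper ultimately delegates to, with a placeholder URL, table, and predicate strings; each predicate becomes one partition whose WHERE clause is pushed down to the JDBC source.

    import java.util.Properties

    val props = new Properties()
    props.setProperty("user", "sa")
    props.setProperty("password", "")

    // One output partition per predicate.
    val predicates = Array(
      "created_at <  '2015-01-01'",
      "created_at >= '2015-01-01'")

    val df = sqlContext.read.jdbc(
      "jdbc:postgresql://localhost/spark_test", "events", predicates, props)
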
--- python/pyspark/sql/readwriter.py | 10 +++++----- python/pyspark/sql/utils.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 7b8ddb9feba3..e8f0d7ec7703 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -26,6 +26,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * +from pyspark.sql import utils __all__ = ["DataFrameReader", "DataFrameWriter"] @@ -131,9 +132,7 @@ def load(self, path=None, format=None, schema=None, **options): if type(path) == list: paths = path gateway = self._sqlContext._sc._gateway - jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) - for i in range(0, len(paths)): - jpaths[i] = paths[i] + jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths) return self._df(self._jreader.load(jpaths)) else: return self._df(self._jreader.load(path)) @@ -269,8 +268,9 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), int(numPartitions), jprop)) if predicates is not None: - arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates) - return self._df(self._jreader.jdbc(url, table, arr, jprop)) + gateway = self._sqlContext._sc._gateway + jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates) + return self._df(self._jreader.jdbc(url, table, jpredicates, jprop)) return self._df(self._jreader.jdbc(url, table, jprop)) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index c4fda8bd3b89..b0a0373372d2 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -71,3 +71,16 @@ def install_exception_handler(): patched = capture_sql_exception(original) # only patch the one used in in py4j.java_gateway (call Java API) py4j.java_gateway.get_return_value = patched + + +def toJArray(gateway, jtype, arr): + """ + Convert python list to java type array + :param gateway: Py4j Gateway + :param jtype: java type of element in array + :param arr: python type list + """ + jarr = gateway.new_array(jtype, len(arr)) + for i in range(0, len(arr)): + jarr[i] = arr[i] + return jarr From a97d6f3a5861e9f2bbe36957e3b39f835f3e214c Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 18 Nov 2015 08:32:03 -0800 Subject: [PATCH 761/896] [SPARK-11281][SPARKR] Add tests covering the issue. The goal of this PR is to add tests covering the issue to ensure that is was resolved by [SPARK-11086](https://issues.apache.org/jira/browse/SPARK-11086). Author: zero323 Closes #9743 from zero323/SPARK-11281-tests. 
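The R test below round-trips list columns (arrays) and environment columns (maps) through createDataFrame. Purely as a point of comparison, and not as part of the patch, the analogous complex-type inference on the Scala side looks like this (the class and column names are made up):

    case class Record(values: Seq[Int], lookup: Map[String, Int])

    // Reflection-based schema inference maps values to array<int> and lookup to
    // map<string,int>, the same shapes the R test expects for its columns.
    val df = sqlContext.createDataFrame(Seq(
      Record(Seq(1, 2), Map("a" -> 1, "b" -> 2)),
      Record(Seq(3, 4), Map("c" -> 3))))

    df.printSchema()
    df.collect()
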
--- R/pkg/inst/tests/test_sparkSQL.R | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 8ff06276599e..87ab33f6384b 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -229,7 +229,7 @@ test_that("create DataFrame from list or data.frame", { df <- createDataFrame(sqlContext, l, c("a", "b")) expect_equal(columns(df), c("a", "b")) - l <- list(list(a=1, b=2), list(a=3, b=4)) + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) df <- createDataFrame(sqlContext, l) expect_equal(columns(df), c("a", "b")) @@ -292,11 +292,15 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { - ldf <- data.frame(row.names=1:2) + ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) + ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) + sdf <- createDataFrame(sqlContext, ldf) + collected <- collect(sdf) - expect_equivalent(ldf, collect(sdf)) + expect_identical(ldf[, 1, FALSE], collected[, 1, FALSE]) + expect_equal(ldf$an_envir, collected$an_envir) }) # For test map type and struct type in DataFrame From 224723e6a8b198ef45d6c5ca5d2f9c61188ada8f Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 18 Nov 2015 08:41:45 -0800 Subject: [PATCH 762/896] [SPARK-11773][SPARKR] Implement collection functions in SparkR. Author: Sun Rui Closes #9764 from sun-rui/SPARK-11773. --- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 2 +- R/pkg/R/functions.R | 109 ++++++++++++++++++++++--------- R/pkg/R/generics.R | 10 ++- R/pkg/R/utils.R | 2 +- R/pkg/inst/tests/test_sparkSQL.R | 10 +++ 6 files changed, 100 insertions(+), 35 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 2ee7d6f94f1b..260c9edce62e 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -98,6 +98,7 @@ exportMethods("%in%", "add_months", "alias", "approxCountDistinct", + "array_contains", "asc", "ascii", "asin", @@ -215,6 +216,7 @@ exportMethods("%in%", "sinh", "size", "skewness", + "sort_array", "soundex", "stddev", "stddev_pop", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fd105ba5bc9b..34177e3cdd94 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2198,4 +2198,4 @@ setMethod("coltypes", rTypes[naIndices] <- types[naIndices] rTypes - }) \ No newline at end of file + }) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 3d0255a62f15..ff0f438045c1 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -373,22 +373,6 @@ setMethod("exp", column(jc) }) -#' explode -#' -#' Creates a new row for each element in the given array or map column. -#' -#' @rdname explode -#' @name explode -#' @family collection_funcs -#' @export -#' @examples \dontrun{explode(df$c)} -setMethod("explode", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) - column(jc) - }) - #' expm1 #' #' Computes the exponential of the given value minus one. @@ -980,22 +964,6 @@ setMethod("sinh", column(jc) }) -#' size -#' -#' Returns length of array or map. -#' -#' @rdname size -#' @name size -#' @family collection_funcs -#' @export -#' @examples \dontrun{size(df$c)} -setMethod("size", - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) - column(jc) - }) - #' skewness #' #' Aggregate function: returns the skewness of the values in a group. 
@@ -2365,3 +2333,80 @@ setMethod("rowNumber", jc <- callJStatic("org.apache.spark.sql.functions", "rowNumber") column(jc) }) + +###################### Collection functions###################### + +#' array_contains +#' +#' Returns true if the array contain the value. +#' +#' @param x A Column +#' @param value A value to be checked if contained in the column +#' @rdname array_contains +#' @name array_contains +#' @family collection_funcs +#' @export +#' @examples \dontrun{array_contains(df$c, 1)} +setMethod("array_contains", + signature(x = "Column", value = "ANY"), + function(x, value) { + jc <- callJStatic("org.apache.spark.sql.functions", "array_contains", x@jc, value) + column(jc) + }) + +#' explode +#' +#' Creates a new row for each element in the given array or map column. +#' +#' @rdname explode +#' @name explode +#' @family collection_funcs +#' @export +#' @examples \dontrun{explode(df$c)} +setMethod("explode", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) + column(jc) + }) + +#' size +#' +#' Returns length of array or map. +#' +#' @rdname size +#' @name size +#' @family collection_funcs +#' @export +#' @examples \dontrun{size(df$c)} +setMethod("size", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) + column(jc) + }) + +#' sort_array +#' +#' Sorts the input array for the given column in ascending order, +#' according to the natural ordering of the array elements. +#' +#' @param x A Column to sort +#' @param asc A logical flag indicating the sorting order. +#' TRUE, sorting is in ascending order. +#' FALSE, sorting is in descending order. +#' @rdname sort_array +#' @name sort_array +#' @family collection_funcs +#' @export +#' @examples +#' \dontrun{ +#' sort_array(df$c) +#' sort_array(df$c, FALSE) +#' } +setMethod("sort_array", + signature(x = "Column"), + function(x, asc = TRUE) { + jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index afdeffc2abd8..0dcd05438222 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -644,6 +644,10 @@ setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) #' @export setGeneric("approxCountDistinct", function(x, ...) 
{ standardGeneric("approxCountDistinct") }) +#' @rdname array_contains +#' @export +setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) + #' @rdname ascii #' @export setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -961,6 +965,10 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @export setGeneric("skewness", function(x) { standardGeneric("skewness") }) +#' @rdname sort_array +#' @export +setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) @@ -1076,4 +1084,4 @@ setGeneric("with") #' @rdname coltypes #' @export -setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) \ No newline at end of file +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index db3b2c4bbd79..45c77a86c958 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -635,4 +635,4 @@ assignNewEnv <- function(data) { assign(x = cols[i], value = data[, cols[i]], envir = env) } env -} \ No newline at end of file +} diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 87ab33f6384b..d9a94faff7ac 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -878,6 +878,16 @@ test_that("column functions", { df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") + + # Test array_contains() and sort_array() + df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) + result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] + expect_equal(result, c(TRUE, FALSE)) + + result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) + result <- collect(select(df, sort_array(df[[1]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) }) # test_that("column binary mathfunctions", { From 3cca5ffb3d60d5de9a54bc71cf0b8279898936d2 Mon Sep 17 00:00:00 2001 From: Hurshal Patel Date: Wed, 18 Nov 2015 09:28:59 -0800 Subject: [PATCH 763/896] [SPARK-11195][CORE] Use correct classloader for TaskResultGetter Make sure we are using the context classloader when deserializing failed TaskResults instead of the Spark classloader. The issue is that `enqueueFailedTask` was using the incorrect classloader which results in `ClassNotFoundException`. Adds a test in TaskResultGetterSuite that compiles a custom exception, throws it on the executor, and asserts that Spark handles the TaskResult deserialization instead of returning `UnknownReason`. See #9367 for previous comments See SPARK-11195 for a full repro Author: Hurshal Patel Closes #9779 from choochootrain/spark-11195-master. 
--- .../scala/org/apache/spark/TestUtils.scala | 11 ++-- .../spark/scheduler/TaskResultGetter.scala | 4 +- .../scheduler/TaskResultGetterSuite.scala | 65 ++++++++++++++++++- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index acfe751f6c74..43c89b258f2f 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} import java.nio.charset.StandardCharsets +import java.nio.file.Paths import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} @@ -83,15 +84,15 @@ private[spark] object TestUtils { } /** - * Create a jar file that contains this set of files. All files will be located at the root - * of the jar. + * Create a jar file that contains this set of files. All files will be located in the specified + * directory or at the root of the jar. */ - def createJar(files: Seq[File], jarFile: File): URL = { + def createJar(files: Seq[File], jarFile: File, directoryPrefix: Option[String] = None): URL = { val jarFileStream = new FileOutputStream(jarFile) val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) for (file <- files) { - val jarEntry = new JarEntry(file.getName) + val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString) jarStream.putNextEntry(jarEntry) val in = new FileInputStream(file) @@ -123,7 +124,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - // Calling this outputs a class file in pwd. It's easier to just rename the file than + // Calling this outputs a class file in pwd. It's easier to just rename the files than // build a custom FileManager that controls the output location. val options = if (classpathUrls.nonEmpty) { Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 46a6f6537e2e..f4965994d827 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -103,16 +103,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul try { getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { + val loader = Utils.getContextOrSparkClassLoader try { if (serializedData != null && serializedData.limit() > 0) { reason = serializer.get().deserialize[TaskEndReason]( - serializedData, Utils.getSparkClassLoader) + serializedData, loader) } } catch { case cnd: ClassNotFoundException => // Log an error but keep going here -- the task failed, so not catastrophic // if we can't deserialize the reason. 
- val loader = Utils.getContextOrSparkClassLoader logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) case ex: Exception => {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 815caa79ff52..bc72c3685e8c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.io.File +import java.net.URL import java.nio.ByteBuffer import scala.concurrent.duration._ @@ -26,8 +28,10 @@ import scala.util.control.NonFatal import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.storage.TaskResultBlockId +import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. @@ -119,5 +123,64 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // Make sure two tasks were run (one failed one, and a second retried one). assert(scheduler.nextTaskId.get() === 2) } + + /** + * Make sure we are using the context classloader when deserializing failed TaskResults instead + * of the Spark classloader. + + * This test compiles a jar containing an exception and tests that when it is thrown on the + * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown + * exception as the cause. + + * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing + * the exception, resulting in an UnknownReason for the TaskEndResult. + */ + test("failed task deserialized with the correct classloader (SPARK-11195)") { + // compile a small jar containing an exception that will be thrown on an executor. + val tempDir = Utils.createTempDir() + val srcDir = new File(tempDir, "repro/") + srcDir.mkdirs() + val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath, + """package repro; + | + |public class MyException extends Exception { + |} + """.stripMargin) + val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty) + val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) + TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro")) + + // ensure we reset the classloader after the test completes + val originalClassLoader = Thread.currentThread.getContextClassLoader + try { + // load the exception from the jar + val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) + loader.addURL(jarFile.toURI.toURL) + Thread.currentThread().setContextClassLoader(loader) + val excClass: Class[_] = Utils.classForName("repro.MyException") + + // NOTE: we must run the cluster with "local" so that the executor can load the compiled + // jar. + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(Seq(1), 1).map { _ => + val exc = excClass.newInstance().asInstanceOf[Exception] + throw exc + } + + // the driver should not have any problems resolving the exception class and determining + // why the task failed. 
+ val exceptionMessage = intercept[SparkException] { + rdd.collect() + }.getMessage + + val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r + val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r + + assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined) + assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty) + } finally { + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } } From cffb899c4397ecccedbcc41e7cf3da91f953435a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:15:50 -0800 Subject: [PATCH 764/896] [SPARK-11803][SQL] fix Dataset self-join When we resolve the join operator, we may change the output of right side if self-join is detected. So in `Dataset.joinWith`, we should resolve the join operator first, and then get the left output and right output from it, instead of using `left.output` and `right.output` directly. Author: Wenchen Fan Closes #9806 from cloud-fan/self-join. --- .../main/scala/org/apache/spark/sql/Dataset.scala | 14 +++++++++----- .../scala/org/apache/spark/sql/DatasetSuite.scala | 8 ++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 817c20fdbb9f..b644f6ad3096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -498,13 +498,17 @@ class Dataset[T] private[sql]( val left = this.logicalPlan val right = other.logicalPlan + val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr))) + val leftOutput = joined.analyzed.output.take(left.output.length) + val rightOutput = joined.analyzed.output.takeRight(right.output.length) + val leftData = this.unresolvedTEncoder match { - case e if e.flat => Alias(left.output.head, "_1")() - case _ => Alias(CreateStruct(left.output), "_1")() + case e if e.flat => Alias(leftOutput.head, "_1")() + case _ => Alias(CreateStruct(leftOutput), "_1")() } val rightData = other.unresolvedTEncoder match { - case e if e.flat => Alias(right.output.head, "_2")() - case _ => Alias(CreateStruct(right.output), "_2")() + case e if e.flat => Alias(rightOutput.head, "_2")() + case _ => Alias(CreateStruct(rightOutput), "_2")() } @@ -513,7 +517,7 @@ class Dataset[T] private[sql]( withPlan[(T, U)](other) { (left, right) => Project( leftData :: rightData :: Nil, - Join(left, right, Inner, Some(condition.expr))) + joined.analyzed) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index a522894c374f..198962b8fb75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -347,7 +347,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(joined, ("2", 2)) } - ignore("self join") { + test("self join") { val ds = Seq("1", "2").toDS().as("a") val joined = ds.joinWith(ds, lit(true)) checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")) @@ -360,15 +360,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] - val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.groupBy(p => p).count().collect().toSeq == Seq((KryoData(1), 1L), (KryoData(2), 1L))) } - 
ignore("kryo encoder self join") { + test("kryo encoder self join") { implicit val kryoEncoder = Encoders.kryo[KryoData] - val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2))) + val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == Set( (KryoData(1), KryoData(1)), From 33b837333435ceb0c04d1f361a5383c4fe6a5a75 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:23:12 -0800 Subject: [PATCH 765/896] [SPARK-11725][SQL] correctly handle null inputs for UDF If user use primitive parameters in UDF, there is no way for him to do the null-check for primitive inputs, so we are assuming the primitive input is null-propagatable for this case and return null if the input is null. Author: Wenchen Fan Closes #9770 from cloud-fan/udf. --- .../spark/sql/catalyst/ScalaReflection.scala | 9 ++++ .../sql/catalyst/analysis/Analyzer.scala | 32 +++++++++++++- .../sql/catalyst/expressions/ScalaUDF.scala | 6 +++ .../sql/catalyst/ScalaReflectionSuite.scala | 17 +++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 44 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++ 6 files changed, 121 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0b3dd351e38e..38828e59a215 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -719,6 +719,15 @@ trait ScalaReflection { } } + /** + * Returns classes of input parameters of scala function object. + */ + def getParameterTypes(func: AnyRef): Seq[Class[_]] = { + val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge) + assert(methods.length == 1) + methods.head.getParameterTypes + } + def typeOfObject: PartialFunction[Any, DataType] = { // The data type can be determined without ambiguity. case obj: Boolean => BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f4670b55bdb..f00c451b5981 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} +import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ /** @@ -85,6 +85,8 @@ class Analyzer( extendedResolutionRules : _*), Batch("Nondeterministic", Once, PullOutNondeterministic), + Batch("UDF", Once, + HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -1063,6 +1065,34 @@ class Analyzer( Project(p.output, newPlan.withNewChildren(newChild :: Nil)) } } + + /** + * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the + * null check. 
When user defines a UDF with primitive parameters, there is no way to tell if the + * primitive parameter is null or not, so here we assume the primitive input is null-propagatable + * and we should return null if the input is null. + */ + object HandleNullInputsForUDF extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p // Skip unresolved nodes. + + case plan => plan transformExpressionsUp { + + case udf @ ScalaUDF(func, _, inputs, _) => + val parameterTypes = ScalaReflection.getParameterTypes(func) + assert(parameterTypes.length == inputs.length) + + val inputsNullCheck = parameterTypes.zip(inputs) + // TODO: skip null handling for not-nullable primitive inputs after we can completely + // trust the `nullable` information. + // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable } + .filter { case (cls, _) => cls.isPrimitive } + .map { case (_, expr) => IsNull(expr) } + .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) + inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf) + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3388cc20a980..03b89221ef2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType /** * User-defined function. + * @param function The user defined scala function to run. + * Note that if you use primitive parameters, you are not able to check if it is + * null or not, and the UDF will return null for you if the primitive input is + * null. Use boxed type or [[Option]] if you wanna do the null-handling yourself. * @param dataType Return type of function. + * @param children The input expressions of this UDF. + * @param inputTypes The expected input types of this UDF. 
*/ case class ScalaUDF( function: AnyRef, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 3b848cfdf737..4ea410d492b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -280,4 +280,21 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType)) } } + + test("get parameter type from a function object") { + val primitiveFunc = (i: Int, j: Long) => "x" + val primitiveTypes = getParameterTypes(primitiveFunc) + assert(primitiveTypes.forall(_.isPrimitive)) + assert(primitiveTypes === Seq(classOf[Int], classOf[Long])) + + val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x" + val boxedTypes = getParameterTypes(boxedFunc) + assert(boxedTypes.forall(!_.isPrimitive)) + assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long])) + + val anyFunc = (i: Any, j: AnyRef) => "x" + val anyTypes = getParameterTypes(anyFunc) + assert(anyTypes.forall(!_.isPrimitive)) + assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object])) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 65f09b46afae..08586a97411a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -174,4 +174,48 @@ class AnalysisSuite extends AnalysisTest { ) assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type")) } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val string = testRelation2.output(0) + val double = testRelation2.output(2) + val short = testRelation2.output(4) + val nullResult = Literal.create(null, StringType) + + def checkUDF(udf: Expression, transformed: Expression): Unit = { + checkAnalysis( + Project(Alias(udf, "")() :: Nil, testRelation2), + Project(Alias(transformed, "")() :: Nil, testRelation2) + ) + } + + // non-primitive parameters do not need special null handling + val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil) + val expected1 = udf1 + checkUDF(udf1, expected1) + + // only primitive parameter needs special null handling + val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil) + val expected2 = If(IsNull(double), nullResult, udf2) + checkUDF(udf2, expected2) + + // special null handling should apply to all primitive parameters + val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil) + val expected3 = If( + IsNull(short) || IsNull(double), + nullResult, + udf3) + checkUDF(udf3, expected3) + + // we can skip special null handling for primitive parameters that are not nullable + // TODO: this is disabled for now as we can not completely trust `nullable`. 
+ val udf4 = ScalaUDF( + (s: Short, d: Double) => "x", + StringType, + short :: double.withNullability(false) :: Nil) + val expected4 = If( + IsNull(short), + nullResult, + udf4) + // checkUDF(udf4, expected4) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 35cdab50bdec..5a7f24684d1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1115,4 +1115,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select(df("*")), Row(1, "a")) checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a")) } + + test("SPARK-11725: correctly handle null inputs for ScalaUDF") { + val df = Seq( + new java.lang.Integer(22) -> "John", + null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name") + + val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { + (i: java.lang.Integer) => if (i == null) null else i * 2 + } + checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil) + + val primitiveUDF = udf((i: Int) => i * 2) + checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) + } } From dbf428c87ab34b6f76c75946043bdf5f60c9b1b3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2015 10:33:17 -0800 Subject: [PATCH 766/896] [SPARK-11795][SQL] combine grouping attributes into a single NamedExpression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit we use `ExpressionEncoder.tuple` to build the result encoder, which assumes the input encoder should point to a struct type field if it’s non-flat. However, our keyEncoder always point to a flat field/fields: `groupingAttributes`, we should combine them into a single `NamedExpression`. Author: Wenchen Fan Closes #9792 from cloud-fan/agg. 
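For intuition, here is a minimal Scala sketch of the key-combining idea (the helper name `combinedKey` is illustrative only; the real change lives in `GroupedDataset` in the diff below):

    import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct, NamedExpression}

    // Fold the flat grouping attributes into a single named key expression, so the
    // downstream tuple encoder sees exactly one (possibly struct-typed) key field.
    def combinedKey(groupingAttributes: Seq[Attribute]): NamedExpression =
      if (groupingAttributes.length > 1) {
        Alias(CreateStruct(groupingAttributes), "key")()
      } else {
        groupingAttributes.head
      }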
--- .../main/scala/org/apache/spark/sql/GroupedDataset.scala | 9 +++++++-- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 5 ++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index c66162ee2148..3f84e22a1025 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -187,7 +187,12 @@ class GroupedDataset[K, T] private[sql]( val namedColumns = columns.map( _.withInputType(resolvedTEncoder, dataAttributes).named) - val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan) + val keyColumn = if (groupingAttributes.length > 1) { + Alias(CreateStruct(groupingAttributes), "key")() + } else { + groupingAttributes.head + } + val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) new Dataset( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 198962b8fb75..b6db583dfe01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -84,8 +84,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 2), ("b", 3), ("c", 4)) } - ignore("Dataset should set the resolved encoders internally for maps") { - // TODO: Enable this once we fix SPARK-11793. + test("map and group by with class data") { // We inject a group by here to make sure this test case is future proof // when we implement better pipelining and local execution mode. val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS() @@ -94,7 +93,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer( ds, - (ClassData("one", 1), 1L), (ClassData("two", 2), 1L)) + (ClassData("one", 2), 1L), (ClassData("two", 3), 1L)) } test("select") { From 90a7519daaa7f4ee3be7c5a9aa244120811ff6eb Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 18 Nov 2015 11:35:41 -0800 Subject: [PATCH 767/896] [MINOR][BUILD] Ignore ensime cache Using ENSIME, I often have `.ensime_cache` polluting my source tree. This PR simply adds the cache directory to `.gitignore` Author: Jakob Odersky Closes #9708 from jodersky/master. 
---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 08f2d8f7543f..07524bc429e9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,6 +50,7 @@ spark-tests.log
 streaming-tests.log
 dependency-reduced-pom.xml
 .ensime
+.ensime_cache/
 .ensime_lucene
 checkpoint
 derby.log

From 6f99522d13d8db9fcc767f7c3189557b9a53d283 Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Wed, 18 Nov 2015 11:49:12 -0800
Subject: [PATCH 768/896] [SPARK-11792] [SQL] [FOLLOW-UP] Change SizeEstimation to KnownSizeEstimation and make estimatedSize return Long instead of Option[Long]

https://issues.apache.org/jira/browse/SPARK-11792

The main changes include:
* Renaming `SizeEstimation` to `KnownSizeEstimation`. Hopefully this new name is more descriptive.
* Making `estimatedSize` return `Long` instead of `Option[Long]`.
* In `UnsafeHashedRelation`, `estimatedSize` will delegate the work to `SizeEstimator` if we have not created a `BytesToBytesMap`. Since we will put `UnsafeHashedRelation` into `BlockManager`, it is generally good to let it provide a more accurate size estimation. Also, if we do not put `BytesToBytesMap` directly into `BlockManager`, I feel it is not really necessary to make `BytesToBytesMap` extend `KnownSizeEstimation`.

Author: Yin Huai

Closes #9813 from yhuai/SPARK-11792-followup.
---
 .../org/apache/spark/util/SizeEstimator.scala | 30 ++++++++++---------
 .../spark/util/SizeEstimatorSuite.scala | 14 ++-------
 .../sql/execution/joins/HashedRelation.scala | 12 +++++---
 3 files changed, 26 insertions(+), 30 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index c3a2675ee5f4..09864e3f8392 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -36,9 +36,14 @@ import org.apache.spark.util.collection.OpenHashSet
  * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first.
  * If `estimatedSize` does not return [[None]], [[SizeEstimator]] will use the returned size
  * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work.
+ * The difference between a [[KnownSizeEstimation]] and
+ * [[org.apache.spark.util.collection.SizeTracker]] is that, a
+ * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to
+ * estimate the size. However, a [[KnownSizeEstimation]] can provide a better estimation without
+ * using [[SizeEstimator]].
  */
-private[spark] trait SizeEstimation {
-  def estimatedSize: Option[Long]
+private[spark] trait KnownSizeEstimation {
+  def estimatedSize: Long
 }

 /**
@@ -209,18 +214,15 @@ object SizeEstimator extends Logging {
       // the size estimator since it references the whole REPL. Do nothing in this case. In
       // general all ClassLoaders and Classes will be shared between objects anyway.
} else { - val estimatedSize = obj match { - case s: SizeEstimation => s.estimatedSize - case _ => None - } - if (estimatedSize.isDefined) { - state.size += estimatedSize.get - } else { - val classInfo = getClassInfo(cls) - state.size += alignSize(classInfo.shellSize) - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) - } + obj match { + case s: KnownSizeEstimation => + state.size += s.estimatedSize + case _ => + val classInfo = getClassInfo(cls) + state.size += alignSize(classInfo.shellSize) + for (field <- classInfo.pointerFields) { + state.enqueue(field.get(obj)) + } } } } diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 9b6261af123e..101610e38014 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -60,16 +60,10 @@ class DummyString(val arr: Array[Char]) { @transient val hash32: Int = 0 } -class DummyClass8 extends SizeEstimation { +class DummyClass8 extends KnownSizeEstimation { val x: Int = 0 - override def estimatedSize: Option[Long] = Some(2015) -} - -class DummyClass9 extends SizeEstimation { - val x: Int = 0 - - override def estimatedSize: Option[Long] = None + override def estimatedSize: Long = 2015 } class SizeEstimatorSuite @@ -231,9 +225,5 @@ class SizeEstimatorSuite // DummyClass8 provides its size estimation. assertResult(2015)(SizeEstimator.estimate(new DummyClass8)) assertResult(20206)(SizeEstimator.estimate(Array.fill(10)(new DummyClass8))) - - // DummyClass9 does not provide its size estimation. - assertResult(16)(SizeEstimator.estimate(new DummyClass9)) - assertResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass9))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 49ae09bf5378..aebfea583240 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.MemoryLocation -import org.apache.spark.util.{SizeEstimation, Utils} +import org.apache.spark.util.{SizeEstimator, KnownSizeEstimation, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -190,7 +190,7 @@ private[execution] object HashedRelation { private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) extends HashedRelation - with SizeEstimation + with KnownSizeEstimation with Externalizable { private[joins] def this() = this(null) // Needed for serialization @@ -217,8 +217,12 @@ private[joins] final class UnsafeHashedRelation( } } - override def estimatedSize: Option[Long] = { - Option(binaryMap).map(_.getTotalMemoryConsumption) + override def estimatedSize: Long = { + if (binaryMap != null) { + binaryMap.getTotalMemoryConsumption + } else { + SizeEstimator.estimate(hashTable) + } } override def get(key: InternalRow): Seq[InternalRow] = { From 94624eacb0fdbbe210894151a956f8150cdf527e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 18 Nov 2015 11:53:28 -0800 Subject: [PATCH 
769/896] [SPARK-11739][SQL] clear the instantiated SQLContext Currently, if the first SQLContext is not removed after stopping SparkContext, a SQLContext could set there forever. This patch make this more robust. Author: Davies Liu Closes #9706 from davies/clear_context. --- .../scala/org/apache/spark/sql/SQLContext.scala | 17 +++++++++++------ .../spark/sql/MultiSQLContextsSuite.scala | 5 ++--- .../execution/ExchangeCoordinatorSuite.scala | 2 +- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cd1fdc4edb39..39471d2fb79a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1229,7 +1229,7 @@ class SQLContext private[sql]( // construction of the instance. sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - SQLContext.clearInstantiatedContext(self) + SQLContext.clearInstantiatedContext() } }) @@ -1270,13 +1270,13 @@ object SQLContext { */ def getOrCreate(sparkContext: SparkContext): SQLContext = { val ctx = activeContext.get() - if (ctx != null) { + if (ctx != null && !ctx.sparkContext.isStopped) { return ctx } synchronized { val ctx = instantiatedContext.get() - if (ctx == null) { + if (ctx == null || ctx.sparkContext.isStopped) { new SQLContext(sparkContext) } else { ctx @@ -1284,12 +1284,17 @@ object SQLContext { } } - private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(sqlContext, null) + private[sql] def clearInstantiatedContext(): Unit = { + instantiatedContext.set(null) } private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = { - instantiatedContext.compareAndSet(null, sqlContext) + synchronized { + val ctx = instantiatedContext.get() + if (ctx == null || ctx.sparkContext.isStopped) { + instantiatedContext.set(sqlContext) + } + } } private[sql] def getInstantiatedContextOption(): Option[SQLContext] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 0e8fcb6a858b..34c5c68fd1c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -31,7 +31,7 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() sparkConf = new SparkConf(false) .setMaster("local[*]") @@ -89,10 +89,9 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { testNewSession(rootSQLContext) testNewSession(rootSQLContext) testCreatingNewSQLContext(allowMultipleSQLContexts) - - SQLContext.clearInstantiatedContext(rootSQLContext) } finally { sc.stop() + SQLContext.clearInstantiatedContext() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 25f2f5caeed1..b96d50a70b85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -34,7 +34,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() SQLContext.clearActive() - originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + SQLContext.clearInstantiatedContext() } override protected def afterAll(): Unit = { From 31921e0f0bd559d042148d1ea32f865fb3068f38 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 18 Nov 2015 12:09:54 -0800 Subject: [PATCH 770/896] [SPARK-4557][STREAMING] Spark Streaming foreachRDD Java API method should accept a VoidFunction<...> Currently streaming foreachRDD Java API uses a function prototype requiring a return value of null. This PR deprecates the old method and uses VoidFunction to allow for more concise declaration. Also added VoidFunction2 to Java API in order to use in Streaming methods. Unit test is added for using foreachRDD with VoidFunction, and changes have been tested with Java 7 and Java 8 using lambdas. Author: Bryan Cutler Closes #9488 from BryanCutler/foreachRDD-VoidFunction-SPARK-4557. --- .../api/java/function/VoidFunction2.java | 27 ++++++++++++ .../apache/spark/streaming/Java8APISuite.java | 26 ++++++++++++ project/MimaExcludes.scala | 4 ++ .../streaming/api/java/JavaDStreamLike.scala | 24 ++++++++++- .../apache/spark/streaming/JavaAPISuite.java | 41 ++++++++++++++++++- 5 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java new file mode 100644 index 000000000000..6c576ab67845 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A two-argument function that takes arguments of type T1 and T2 with no return value. 
+ */ +public interface VoidFunction2 extends Serializable { + public void call(T1 v1, T2 v2) throws Exception; +} diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 163ae92c12c6..4eee97bc8961 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -28,6 +28,7 @@ import org.junit.Assert; import org.junit.Test; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; @@ -360,6 +361,31 @@ public void testFlatMap() { assertOrderInvariantEquals(expected, result); } + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(rdd -> { + accumRdd.add(1); + rdd.foreach(x -> accumEle.add(1)); + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java + stream.foreachRDD((rdd, time) -> null); + + JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(2, accumRdd.value().intValue()); + Assert.assertEquals(6, accumEle.value().intValue()); + } + @Test public void testPairFlatMap() { List> inputData = Arrays.asList( diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eb70d27c34c2..bb45d1bb1214 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -142,6 +142,10 @@ object MimaExcludes { "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") + ) ++ Seq( + // SPARK-4557 Changed foreachRDD to use VoidFunction + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") ) case v if v.startsWith("1.5") => Seq( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index edfa474677f1..84acec7d8e33 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaRDDLike} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, _} +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, VoidFunction => JVoidFunction, VoidFunction2 => JVoidFunction2, _} import org.apache.spark.rdd.RDD import org.apache.spark.streaming._ import org.apache.spark.streaming.api.java.JavaDStream._ @@ -308,7 +308,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. 
This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction[R])", "1.6.0") def foreachRDD(foreachFunc: JFunction[R, Void]) { dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) } @@ -316,11 +319,30 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 1.6.0, replaced by foreachRDD(JVoidFunction2) */ + @deprecated("Use foreachRDD(foreachFunc: JVoidFunction2[R, Time])", "1.6.0") def foreachRDD(foreachFunc: JFunction2[R, Time, Void]) { dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) } + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JVoidFunction[R]) { + dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JVoidFunction2[R, Time]) { + dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) + } + /** * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index c5217149224e..609bb4413b6b 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -37,7 +37,9 @@ import com.google.common.io.Files; import com.google.common.collect.Sets; +import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -45,7 +47,6 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; -import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -768,6 +769,44 @@ public Iterable call(String x) { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") + @Test + public void testForeachRDD() { + final Accumulator accumRdd = ssc.sc().accumulator(0); + final Accumulator accumEle = ssc.sc().accumulator(0); + List> inputData = Arrays.asList( + Arrays.asList(1,1,1), + Arrays.asList(1,1,1)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaTestUtils.attachTestOutputStream(stream.count()); // dummy output + + stream.foreachRDD(new VoidFunction>() { + @Override + public void call(JavaRDD rdd) { + accumRdd.add(1); + rdd.foreach(new VoidFunction() { + @Override + public void call(Integer i) { + accumEle.add(1); + } + }); + } + }); + + // This is a test to make sure foreachRDD(VoidFunction2) can be called from 
Java
+    stream.foreachRDD(new VoidFunction2<JavaRDD<Integer>, Time>() {
+      @Override
+      public void call(JavaRDD<Integer> rdd, Time time) {
+      }
+    });
+
+    JavaTestUtils.runStreams(ssc, 2, 2);
+
+    Assert.assertEquals(2, accumRdd.value().intValue());
+    Assert.assertEquals(6, accumEle.value().intValue());
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void testPairFlatMap() {

From a416e41e285700f861559d710dbf857405bfddf6 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Wed, 18 Nov 2015 12:50:29 -0800
Subject: [PATCH 771/896] [SPARK-11809] Switch the default Mesos mode to coarse-grained mode

Based on my conversations with people, I believe the consensus is that the coarse-grained mode is more stable and easier to reason about. It is best to use that as the default rather than the more flaky fine-grained mode.

Author: Reynold Xin

Closes #9795 from rxin/SPARK-11809.
---
 .../scala/org/apache/spark/SparkContext.scala | 2 +-
 docs/job-scheduling.md | 2 +-
 docs/running-on-mesos.md | 27 ++++++++++++-------
 3 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index b5645b08f92d..ab374cb71286 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2710,7 +2710,7 @@ object SparkContext extends Logging {
       case mesosUrl @ MESOS_REGEX(_) =>
         MesosNativeLibrary.load()
         val scheduler = new TaskSchedulerImpl(sc)
-        val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false)
+        val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true)
         val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs
         val backend = if (coarseGrained) {
           new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager)
diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md
index a3c34cb6796f..36327c6efeaf 100644
--- a/docs/job-scheduling.md
+++ b/docs/job-scheduling.md
@@ -47,7 +47,7 @@ application is not running tasks on a machine, other applications may run tasks
 is useful when you expect large numbers of not overly active applications, such as shell sessions from
 separate users. However, it comes with a risk of less predictable latency, because it may take a while for
 an application to gain back cores on one node when it has work to do. To use this mode, simply use a
-`mesos://` URL without setting `spark.mesos.coarse` to true.
+`mesos://` URL and set `spark.mesos.coarse` to false.

 Note that none of the modes currently provide memory sharing across applications. If you would like to share
 data this way, we recommend running a single server application that can serve multiple requests by querying
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 5be208cf3461..a197d0e37302 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -161,21 +161,15 @@ Note that jars or python files that are passed to spark-submit should be URIs re

 # Mesos Run Modes

-Spark can run over Mesos in two modes: "fine-grained" (default) and "coarse-grained".
+Spark can run over Mesos in two modes: "coarse-grained" (default) and "fine-grained".

-In "fine-grained" mode (default), each Spark task runs as a separate Mesos task. This allows
-multiple instances of Spark (and other frameworks) to share machines at a very fine granularity,
-where each application gets more or fewer machines as it ramps up and down, but it comes with an
-additional overhead in launching each task. This mode may be inappropriate for low-latency
-requirements like interactive queries or serving web requests.
-
-The "coarse-grained" mode will instead launch only *one* long-running Spark task on each Mesos
+The "coarse-grained" mode will launch only *one* long-running Spark task on each Mesos
 machine, and dynamically schedule its own "mini-tasks" within it. The benefit is much lower startup
 overhead, but at the cost of reserving the Mesos resources for the complete duration of the
 application.

-To run in coarse-grained mode, set the `spark.mesos.coarse` property in your
-[SparkConf](configuration.html#spark-properties):
+Coarse-grained is the default mode. You can also set the `spark.mesos.coarse` property to true
+to turn it on explicitly in [SparkConf](configuration.html#spark-properties):

 {% highlight scala %}
 conf.set("spark.mesos.coarse", "true")
@@ -186,6 +180,19 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere
 only makes sense if you run just one application at a time. You can cap the maximum number of
 cores using `conf.set("spark.cores.max", "10")` (for example).

+In "fine-grained" mode, each Spark task runs as a separate Mesos task. This allows
+multiple instances of Spark (and other frameworks) to share machines at a very fine granularity,
+where each application gets more or fewer machines as it ramps up and down, but it comes with an
+additional overhead in launching each task. This mode may be inappropriate for low-latency
+requirements like interactive queries or serving web requests.
+
+To run in fine-grained mode, set the `spark.mesos.coarse` property to false in your
+[SparkConf](configuration.html#spark-properties):
+
+{% highlight scala %}
+conf.set("spark.mesos.coarse", "false")
+{% endhighlight %}
+
 You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted.

 {% highlight scala %}

From 7c5b641808740ba5eed05ba8204cdbaf3fc579f5 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu
Date: Wed, 18 Nov 2015 12:53:22 -0800
Subject: [PATCH 772/896] [SPARK-10745][CORE] Separate configs between shuffle and RPC

[SPARK-6028](https://issues.apache.org/jira/browse/SPARK-6028) uses the network module to implement RPC. However, there are some configurations named with the `spark.shuffle` prefix in the network module. This PR refactors them to make sure the user can control them in shuffle and RPC separately. The user can use `spark.rpc.*` to set the configuration for netty RPC.

Author: Shixiong Zhu

Closes #9481 from zsxwing/SPARK-10745.
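As a usage sketch (the values below are made-up examples, not recommended settings), the same I/O option can now be tuned independently for the two modules, because each `TransportConf` is created with its own module name (`SparkTransportConf.fromSparkConf(conf, "shuffle")` for the shuffle path, and the "rpc" module inside `NettyRpcEnv`):

    import org.apache.spark.SparkConf

    // Shuffle and RPC transport settings no longer share the spark.shuffle.* namespace.
    val conf = new SparkConf()
      .set("spark.shuffle.io.serverThreads", "8")      // read only by the "shuffle" TransportConf
      .set("spark.rpc.io.numConnectionsPerPeer", "1")  // read only by the "rpc" TransportConf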
--- .../spark/deploy/ExternalShuffleService.scala | 3 +- .../netty/NettyBlockTransferService.scala | 2 +- .../network/netty/SparkTransportConf.scala | 12 ++-- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 8 +-- .../mesos/CoarseMesosSchedulerBackend.scala | 2 +- .../shuffle/FileShuffleBlockResolver.scala | 2 +- .../shuffle/IndexShuffleBlockResolver.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 2 +- .../spark/ExternalShuffleServiceSuite.scala | 2 +- .../spark/network/util/TransportConf.java | 65 ++++++++++++++----- .../network/ChunkFetchIntegrationSuite.java | 2 +- .../RequestTimeoutIntegrationSuite.java | 2 +- .../spark/network/RpcIntegrationSuite.java | 2 +- .../org/apache/spark/network/StreamSuite.java | 2 +- .../network/TransportClientFactorySuite.java | 6 +- .../spark/network/sasl/SparkSaslSuite.java | 6 +- .../network/sasl/SaslIntegrationSuite.java | 2 +- .../ExternalShuffleBlockResolverSuite.java | 2 +- .../shuffle/ExternalShuffleCleanupSuite.java | 2 +- .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/ExternalShuffleSecuritySuite.java | 2 +- .../shuffle/RetryingBlockFetcherSuite.java | 2 +- .../network/yarn/YarnShuffleService.java | 2 +- 23 files changed, 84 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index a039d543c35e..e8a1e35c3fc4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -45,7 +45,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) private val useSasl: Boolean = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) + private val transportConf = + SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) private val blockHandler = newShuffleBlockHandler(transportConf) private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler, true) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 70a42f9045e6..b0694e3c6c8a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -41,7 +41,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index cef203006d68..84833f59d7af 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -40,23 +40,23 @@ object SparkTransportConf { /** * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + * @param _conf the [[SparkConf]] + * @param module the module name * @param numUsableCores if nonzero, this will restrict the server and client threads to only * use the given number of cores, rather than all of the machine's cores. * This restriction will only occur if these properties are not already set. */ - def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = { + def fromSparkConf(_conf: SparkConf, module: String, numUsableCores: Int = 0): TransportConf = { val conf = _conf.clone // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily // assuming we have all the machine's cores). // NB: Only set if serverThreads/clientThreads not already set. val numThreads = defaultNumThreads(numUsableCores) - conf.set("spark.shuffle.io.serverThreads", - conf.get("spark.shuffle.io.serverThreads", numThreads.toString)) - conf.set("spark.shuffle.io.clientThreads", - conf.get("spark.shuffle.io.clientThreads", numThreads.toString)) + conf.setIfMissing(s"spark.$module.io.serverThreads", numThreads.toString) + conf.setIfMissing(s"spark.$module.io.clientThreads", numThreads.toString) - new TransportConf(new ConfigProvider { + new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) }) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 09093819bb22..3e0c49796950 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -22,16 +22,13 @@ import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy +import javax.annotation.Nullable -import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success} import scala.util.control.NonFatal -import com.google.common.base.Preconditions import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -49,7 +46,8 @@ private[netty] class NettyRpcEnv( securityManager: SecurityManager) extends RpcEnv(conf) with Logging { private val transportConf = SparkTransportConf.fromSparkConf( - conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), + conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), + "rpc", conf.getInt("spark.rpc.io.threads", 0)) private val 
dispatcher: Dispatcher = new Dispatcher(this) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 2de9b6a65169..7d08eae0b487 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -109,7 +109,7 @@ private[spark] class CoarseMesosSchedulerBackend( private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { Some(new MesosExternalShuffleClient( - SparkTransportConf.fromSparkConf(conf), + SparkTransportConf.fromSparkConf(conf, "shuffle"), securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled())) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 39fadd878351..cc5f933393ad 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -46,7 +46,7 @@ private[spark] trait ShuffleWriterGroup { private[spark] class FileShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver with Logging { - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") private lazy val blockManager = SparkEnv.get.blockManager diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 05b1eed7f3be..fadb8fe7ed0a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -47,7 +47,7 @@ private[spark] class IndexShuffleBlockResolver( private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") def getDataFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 661c706af32b..ab0007fb7899 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -122,7 +122,7 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. 
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), securityManager.isSaslEncryptionEnabled()) } else { diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 231f4631e0a4..1c775bcb3d9c 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -35,7 +35,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { - val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) + val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 3b2eff377955..115135d44adb 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -23,18 +23,53 @@ * A central location that tracks all the settings we expose to users. */ public class TransportConf { + + private final String SPARK_NETWORK_IO_MODE_KEY; + private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; + private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY; + private final String SPARK_NETWORK_IO_BACKLOG_KEY; + private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY; + private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY; + private final String SPARK_NETWORK_IO_CLIENTTHREADS_KEY; + private final String SPARK_NETWORK_IO_RECEIVEBUFFER_KEY; + private final String SPARK_NETWORK_IO_SENDBUFFER_KEY; + private final String SPARK_NETWORK_SASL_TIMEOUT_KEY; + private final String SPARK_NETWORK_IO_MAXRETRIES_KEY; + private final String SPARK_NETWORK_IO_RETRYWAIT_KEY; + private final String SPARK_NETWORK_IO_LAZYFD_KEY; + private final ConfigProvider conf; - public TransportConf(ConfigProvider conf) { + private final String module; + + public TransportConf(String module, ConfigProvider conf) { + this.module = module; this.conf = conf; + SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode"); + SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs"); + SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout"); + SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog"); + SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer"); + SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads"); + SPARK_NETWORK_IO_CLIENTTHREADS_KEY = getConfKey("io.clientThreads"); + SPARK_NETWORK_IO_RECEIVEBUFFER_KEY = getConfKey("io.receiveBuffer"); + SPARK_NETWORK_IO_SENDBUFFER_KEY = getConfKey("io.sendBuffer"); + SPARK_NETWORK_SASL_TIMEOUT_KEY = getConfKey("sasl.timeout"); + SPARK_NETWORK_IO_MAXRETRIES_KEY = getConfKey("io.maxRetries"); + SPARK_NETWORK_IO_RETRYWAIT_KEY = getConfKey("io.retryWait"); + 
SPARK_NETWORK_IO_LAZYFD_KEY = getConfKey("io.lazyFD"); + } + + private String getConfKey(String suffix) { + return "spark." + module + "." + suffix; } /** IO mode: nio or epoll */ - public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } /** If true, we will prefer allocating off-heap byte buffers within Netty. */ public boolean preferDirectBufs() { - return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true); } /** Connect timeout in milliseconds. Default 120 secs. */ @@ -42,23 +77,23 @@ public int connectionTimeoutMs() { long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec( conf.get("spark.network.timeout", "120s")); long defaultTimeoutMs = JavaUtils.timeStringAsSec( - conf.get("spark.shuffle.io.connectionTimeout", defaultNetworkTimeoutS + "s")) * 1000; + conf.get(SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY, defaultNetworkTimeoutS + "s")) * 1000; return (int) defaultTimeoutMs; } /** Number of concurrent connections between two nodes for fetching data. */ public int numConnectionsPerPeer() { - return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 1); + return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1); } /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ - public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } + public int backLog() { return conf.getInt(SPARK_NETWORK_IO_BACKLOG_KEY, -1); } /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ - public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } + public int serverThreads() { return conf.getInt(SPARK_NETWORK_IO_SERVERTHREADS_KEY, 0); } /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ - public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } + public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); } /** * Receive buffer size (SO_RCVBUF). @@ -67,28 +102,28 @@ public int numConnectionsPerPeer() { * Assuming latency = 1ms, network_bandwidth = 10Gbps * buffer size should be ~ 1.25MB */ - public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } + public int receiveBuf() { return conf.getInt(SPARK_NETWORK_IO_RECEIVEBUFFER_KEY, -1); } /** Send buffer size (SO_SNDBUF). */ - public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); } /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ public int saslRTTimeoutMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.sasl.timeout", "30s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000; } /** * Max number of times we will try IO exceptions (such as connection timeouts) per request. * If set to 0, we will not do any retries. */ - public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + public int maxIORetries() { return conf.getInt(SPARK_NETWORK_IO_MAXRETRIES_KEY, 3); } /** * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. * Only relevant if maxIORetries > 0. 
*/ public int ioRetryWaitTimeMs() { - return (int) JavaUtils.timeStringAsSec(conf.get("spark.shuffle.io.retryWait", "5s")) * 1000; + return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_IO_RETRYWAIT_KEY, "5s")) * 1000; } /** @@ -101,11 +136,11 @@ public int memoryMapBytes() { } /** - * Whether to initialize shuffle FileDescriptor lazily or not. If true, file descriptors are + * Whether to initialize FileDescriptor lazily or not. If true, file descriptors are * created only when data is going to be transferred. This can reduce the number of open files. */ public boolean lazyFileDescriptor() { - return conf.getBoolean("spark.shuffle.io.lazyFD", true); + return conf.getBoolean(SPARK_NETWORK_IO_LAZYFD_KEY, true); } /** diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dfb7740344ed..dc5fa1cee69b 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -83,7 +83,7 @@ public static void setUp() throws Exception { fp.write(fileContent); fp.close(); - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 84ebb337e6d5..42955ef69235 100644 --- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -60,7 +60,7 @@ public class RequestTimeoutIntegrationSuite { public void setUp() throws Exception { Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.connectionTimeout", "2s"); - conf = new TransportConf(new MapConfigProvider(configMap)); + conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); defaultManager = new StreamManager() { @Override diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 64b457b4b3f0..8eb56bdd9846 100644 --- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -49,7 +49,7 @@ public class RpcIntegrationSuite { @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); rpcHandler = new RpcHandler() { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 6dcec831dec7..00158fd08162 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -89,7 +89,7 @@ public static void setUp() throws 
Exception { fp.close(); } - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); final StreamManager streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index f44713741930..dac7d4a5b0a0 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -52,7 +52,7 @@ public class TransportClientFactorySuite { @Before public void setUp() { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); RpcHandler rpcHandler = new NoOpRpcHandler(); context = new TransportContext(conf, rpcHandler); server1 = context.createServer(); @@ -76,7 +76,7 @@ private void testClientReuse(final int maxConnections, boolean concurrent) Map configMap = Maps.newHashMap(); configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); - TransportConf conf = new TransportConf(new MapConfigProvider(configMap)); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); @@ -182,7 +182,7 @@ public void closeBlockClientsWithFactory() throws IOException { @Test public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { - TransportConf conf = new TransportConf(new ConfigProvider() { + TransportConf conf = new TransportConf("shuffle", new ConfigProvider() { @Override public String get(String name) { diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 3469e84e7f4d..b14689967018 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -207,7 +207,7 @@ public void testEncryptedMessage() throws Exception { public void testEncryptedMessageChunking() throws Exception { File file = File.createTempFile("sasltest", ".txt"); try { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); byte[] data = new byte[8 * 1024]; new Random().nextBytes(data); @@ -242,7 +242,7 @@ public void testFileRegionEncryption() throws Exception { final File file = File.createTempFile("sasltest", ".txt"); SaslTestCtx ctx = null; try { - final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); StreamManager sm = mock(StreamManager.class); when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { @Override @@ -368,7 +368,7 @@ private static class SaslTestCtx { boolean disableClientEncryption) throws Exception { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); SecretKeyHolder keyHolder = 
mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index c393a5e1e681..1c2fa4d0d462 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -70,7 +70,7 @@ public class SaslIntegrationSuite { @BeforeClass public static void beforeAll() throws IOException { - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); context = new TransportContext(conf, new TestRpcHandler()); secretKeyHolder = mock(SecretKeyHolder.class); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 3c6cb367dea4..a9958232a1d2 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -42,7 +42,7 @@ public class ExternalShuffleBlockResolverSuite { static TestShuffleDataContext dataContext; - static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + static TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @BeforeClass public static void beforeAll() throws IOException { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 2f4f1d0df478..532d7ab8d01b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -35,7 +35,7 @@ public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. 
Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); @Test public void noCleanupAndCleanup() throws IOException { diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index a3f9a38b1aeb..2095f41d79c1 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -91,7 +91,7 @@ public static void beforeAll() throws IOException { dataContext1.create(); dataContext1.insertHashShuffleData(1, 0, exec1Blocks); - conf = new TransportConf(new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index aa99efda9494..08ddb3755bd0 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -39,7 +39,7 @@ public class ExternalShuffleSecuritySuite { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); TransportServer server; @Before diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 06e46f924109..3a6ef0d3f847 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -254,7 +254,7 @@ private static void performInteractions(List> inte BlockFetchingListener listener) throws IOException { - TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); Stubber stub = null; diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 11ea7f3fd3cf..ba6d30a74c67 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -120,7 +120,7 @@ protected void serviceInit(Configuration conf) { registeredExecutorFile = findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs")); - TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle server to 
use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); From 09ad9533d5760652de59fa4830c24cb8667958ac Mon Sep 17 00:00:00 2001 From: JihongMa Date: Wed, 18 Nov 2015 13:03:37 -0800 Subject: [PATCH 773/896] [SPARK-11720][SQL][ML] Handle edge cases when count = 0 or 1 for Stats function. Return Double.NaN for mean/average when count == 0 for all numeric types that are converted to Double; Decimal type continues to return null. Author: JihongMa Closes #9705 from JihongMA/SPARK-11720. --- python/pyspark/sql/dataframe.py | 2 +- .../aggregate/CentralMomentAgg.scala | 2 +- .../expressions/aggregate/Kurtosis.scala | 9 +++++---- .../expressions/aggregate/Skewness.scala | 9 +++++---- .../expressions/aggregate/Stddev.scala | 18 ++++++++++++++---- .../expressions/aggregate/Variance.scala | 18 ++++++++++++++---- .../spark/sql/DataFrameAggregateSuite.scala | 18 ++++++++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 8 files changed, 53 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ad6ad0235a90..0dd75ba7ca82 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -761,7 +761,7 @@ def describe(self, *cols): +-------+------------------+-----+ | count| 2| 2| | mean| 3.5| null| - | stddev|2.1213203435596424| NaN| + | stddev|2.1213203435596424| null| | min| 2|Alice| | max| 5| Bob| +-------+------------------+-----+ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index de5872ab11eb..d07d4c338cdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -206,7 +206,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w * @param centralMoments Length `momentOrder + 1` array of central moments (un-normalized) * needed to compute the aggregate stat.
*/ - def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double + def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Any override final def eval(buffer: InternalRow): Any = { val n = buffer.getDouble(nOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala index 8fa3aac9f1a5..c2bf2cb94116 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -37,16 +37,17 @@ case class Kurtosis(child: Expression, override protected val momentOrder = 4 // NOTE: this is the formula for excess kurtosis, which is default for R and SciPy - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m4 = moments(4) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { n * m4 / (m2 * m2) - 3.0 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala index e1c01a5b8278..9411bcea2539 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -36,16 +36,17 @@ case class Skewness(child: Expression, override protected val momentOrder = 3 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m3 = moments(3) - if (n == 0.0 || m2 == 0.0) { + if (n == 0.0) { + null + } else if (m2 == 0.0) { Double.NaN - } - else { + } else { math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 05dd5e3b2254..eec79a9033e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -36,11 +36,17 @@ case class StddevSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0)) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + math.sqrt(moments(2) / (n - 1.0)) + } } } @@ -62,10 +68,14 @@ case class StddevPop( override protected val momentOrder = 2 
- override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n) + if (n == 0.0) { + null + } else { + math.sqrt(moments(2) / n) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala index ede2da280596..cf3a74030539 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -36,11 +36,17 @@ case class VarianceSamp(child: Expression, override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) + if (n == 0.0) { + null + } else if (n == 1.0) { + Double.NaN + } else { + moments(2) / (n - 1.0) + } } } @@ -62,10 +68,14 @@ case class VariancePop( override protected val momentOrder = 2 - override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = { require(moments.length == momentOrder + 1, s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - if (n == 0.0) Double.NaN else moments(2) / n + if (n == 0.0) { + null + } else { + moments(2) / n + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 432e8d17623a..71adf2148a40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), - Row(Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null)) } test("zero sum") { @@ -244,17 +244,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("zero moments") { val input = Seq((1, 2)).toDF("a", "b") checkAnswer( - input.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + input.agg(stddev('a), stddev_samp('a), stddev_pop('a), variance('a), + var_samp('a), var_pop('a), skewness('a), kurtosis('a)), + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) checkAnswer( input.agg( + expr("stddev(a)"), + expr("stddev_samp(a)"), + expr("stddev_pop(a)"), expr("variance(a)"), expr("var_samp(a)"), expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN)) + Row(Double.NaN, Double.NaN, 0.0, Double.NaN, Double.NaN, 0.0, + Double.NaN, Double.NaN)) } test("null moments") { 
@@ -262,7 +268,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer( emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) checkAnswer( emptyTableData.agg( @@ -271,6 +277,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { expr("var_pop(a)"), expr("skewness(a)"), expr("kurtosis(a)")), - Row(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)) + Row(null, null, null, null, null)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5a7f24684d1b..6399b0165c4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val emptyDescribeResult = Seq( Row("count", "0", "0"), Row("mean", null, null), - Row("stddev", "NaN", "NaN"), + Row("stddev", null, null), Row("min", null, null), Row("max", null, null)) From 045a4f045821dcf60442f0600c2df1b79bddb536 Mon Sep 17 00:00:00 2001 From: Wenjian Huang Date: Wed, 18 Nov 2015 13:06:25 -0800 Subject: [PATCH 774/896] [SPARK-6790][ML] Add spark.ml LinearRegression import/export This replaces [https://github.com/apache/spark/pull/9656] with updates. fayeshine should be the main author when this PR is committed. CC: mengxr fayeshine Author: Wenjian Huang Author: Joseph K. Bradley Closes #9814 from jkbradley/fayeshine-patch-6790. --- .../ml/regression/LinearRegression.scala | 77 ++++++++++++++++++- .../ml/regression/LinearRegressionSuite.scala | 34 +++++++- 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 913140e58198..ca55d5915e68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} import org.apache.spark.ml.feature.Instance @@ -30,7 +31,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -65,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Logging { + with LinearRegressionParams with Writable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -341,6 +342,19 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: 
String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object LinearRegression extends Readable[LinearRegression] { + + @Since("1.6.0") + override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] + + @Since("1.6.0") + override def load(path: String): LinearRegression = read.load(path) } /** @@ -354,7 +368,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams { + with LinearRegressionParams with Writable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -422,6 +436,63 @@ class LinearRegressionModel private[ml] ( if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } + + /** + * Returns a [[Writer]] instance for this ML instance. + * + * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + * This also does not save the [[parent]] currently. + */ + @Since("1.6.0") + override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) +} + +@Since("1.6.0") +object LinearRegressionModel extends Readable[LinearRegressionModel] { + + @Since("1.6.0") + override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + + @Since("1.6.0") + override def load(path: String): LinearRegressionModel = read.load(path) + + /** [[Writer]] instance for [[LinearRegressionModel]] */ + private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) + extends Writer with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } + } + + private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + + override def load(path: String): LinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index a1d86fe8feda..2bdc0e184d73 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -22,14 +22,15 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance 
import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private val seed: Int = 42 @transient var datasetWithDenseFeature: DataFrame = _ @@ -854,4 +855,33 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) } } + + test("read/write") { + def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + } + val lr = new LinearRegression() + testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, + checkModelData) + } +} + +object LinearRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "regParam" -> 0.01, + "elasticNetParam" -> 0.1, + "maxIter" -> 2, // intentionally small + "fitIntercept" -> true, + "tol" -> 0.8, + "standardization" -> false, + "solver" -> "l-bfgs" + ) } From 2acdf10b1f3bb1242dba64efa798c672fde9f0d2 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 18 Nov 2015 13:16:31 -0800 Subject: [PATCH 775/896] [SPARK-6789][ML] Add Readable, Writable support for spark.ml ALS, ALSModel Also modifies DefaultParamsWriter.saveMetadata to take optional extra metadata. CC: mengxr yanboliang Author: Joseph K. Bradley Closes #9786 from jkbradley/als-io. 
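To illustrate what this change gives users, here is a hypothetical usage sketch based only on the Writable/Readable signatures in the diff below; the sqlContext, the toy ratings and the /tmp paths are assumptions for the example, not part of the patch:

import org.apache.spark.ml.recommendation.{ALS, ALSModel}

// Assumes a spark-shell style `sqlContext`; the toy ratings are illustrative only.
val ratings = sqlContext.createDataFrame(Seq(
  (0, 0, 4.0f), (0, 1, 2.0f), (1, 1, 3.0f), (1, 2, 5.0f)
)).toDF("user", "item", "rating")

val als = new ALS().setMaxIter(5).setRank(10)
  .setUserCol("user").setItemCol("item").setRatingCol("rating")
val model = als.fit(ratings)

// Persist the estimator (params only) and the fitted model; the model writer below
// stores rank in the JSON metadata and the user/item factors as Parquet.
als.write.save("/tmp/als-estimator")
model.write.save("/tmp/als-model")

// Reload them later.
val sameAls = ALS.load("/tmp/als-estimator")
val sameModel = ALSModel.load("/tmp/als-model")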
--- .../apache/spark/ml/recommendation/ALS.scala | 75 ++++++++++++++++-- .../org/apache/spark/ml/util/ReadWrite.scala | 14 +++- .../spark/ml/recommendation/ALSSuite.scala | 78 ++++++++++++++++--- 3 files changed, 150 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 535f266b9a94..d92514d2e239 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,13 +27,16 @@ import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} +import org.json4s.{DefaultFormats, JValue} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD @@ -182,7 +185,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams { + extends Model[ALSModel] with ALSModelParams with Writable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -220,8 +223,60 @@ class ALSModel private[ml] ( val copied = new ALSModel(uid, rank, userFactors, itemFactors) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new ALSModel.ALSModelWriter(this) } +@Since("1.6.0") +object ALSModel extends Readable[ALSModel] { + + @Since("1.6.0") + override def read: Reader[ALSModel] = new ALSModelReader + + @Since("1.6.0") + override def load(path: String): ALSModel = read.load(path) + + private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata = render("rank" -> instance.rank) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val userPath = new Path(path, "userFactors").toString + instance.userFactors.write.format("parquet").save(userPath) + val itemPath = new Path(path, "itemFactors").toString + instance.itemFactors.write.format("parquet").save(itemPath) + } + } + + private[recommendation] class ALSModelReader extends Reader[ALSModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.recommendation.ALSModel" + + override def load(path: String): ALSModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + implicit val format = DefaultFormats + val rank: Int = metadata.extraMetadata match { + case Some(m: JValue) => + (m \ "rank").extract[Int] + case None => + throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" + + s" ${metadata.metadataStr}") + } + + val userPath = new Path(path, "userFactors").toString + val userFactors = sqlContext.read.format("parquet").load(userPath) + val itemPath = new Path(path, "itemFactors").toString + val itemFactors = 
sqlContext.read.format("parquet").load(itemPath) + + val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} /** * :: Experimental :: @@ -254,7 +309,7 @@ class ALSModel private[ml] ( * preferences rather than explicit ratings given to items. */ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -336,8 +391,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { } override def copy(extra: ParamMap): ALS = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } + /** * :: DeveloperApi :: * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is @@ -347,7 +406,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { * than 2 billion. */ @DeveloperApi -object ALS extends Logging { +object ALS extends Readable[ALS] with Logging { /** * :: DeveloperApi :: @@ -356,6 +415,12 @@ object ALS extends Logging { @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) + @Since("1.6.0") + override def read: Reader[ALS] = new DefaultParamsReader[ALS] + + @Since("1.6.0") + override def load(path: String): ALS = read.load(path) + /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { /** Solves a least squares problem with regularization (possibly with other constraints). */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index dddb72af5ba7..d8ce907af532 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -194,7 +194,11 @@ private[ml] object DefaultParamsWriter { * - uid * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. */ - def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = { + def saveMetadata( + instance: Params, + path: String, + sc: SparkContext, + extraMetadata: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] @@ -205,7 +209,8 @@ private[ml] object DefaultParamsWriter { ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) + ("paramMap" -> jsonParams) ~ + ("extraMetadata" -> extraMetadata) val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -236,6 +241,7 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. 
* @param params paramMap, as a [[JValue]] + * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]] * @param metadataStr Full metadata file String (for debugging) */ case class Metadata( @@ -244,6 +250,7 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, + extraMetadata: Option[JValue], metadataStr: String) /** @@ -262,12 +269,13 @@ private[ml] object DefaultParamsReader { val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] val params = metadata \ "paramMap" + val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]] if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, metadataStr) + Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index eadc80e0e62b..2c3fb84160dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.recommendation -import java.io.File import java.util.Random import scala.collection.mutable @@ -26,28 +25,26 @@ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.sql.{DataFrame, Row} -class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { - private var tempDir: File = _ +class ALSSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() - tempDir = Utils.createTempDir() sc.setCheckpointDir(tempDir.getAbsolutePath) } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) super.afterAll() } @@ -186,7 +183,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5)) var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)] var i = 0 - while (i < compressed.srcIds.size) { + while (i < compressed.srcIds.length) { var j = compressed.dstPtrs(i) while (j < compressed.dstPtrs(i + 1)) { val dstEncodedIndex = compressed.dstEncodedIndices(j) @@ -483,4 +480,67 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true, seed = 0) } + + test("read/write") { + import ALSSuite._ + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) + val als = new ALS() + allEstimatorParamSettings.foreach { case (p, v) => + als.set(als.getParam(p), v) + } + val sqlContext = 
this.sqlContext + import sqlContext.implicits._ + val model = als.fit(ratings.toDF()) + + // Test Estimator save/load + val als2 = testDefaultReadWrite(als) + allEstimatorParamSettings.foreach { case (p, v) => + val param = als.getParam(p) + assert(als.get(param).get === als2.get(param).get) + } + + // Test Model save/load + val model2 = testDefaultReadWrite(model) + allModelParamSettings.foreach { case (p, v) => + val param = model.getParam(p) + assert(model.get(param).get === model2.get(param).get) + } + assert(model.rank === model2.rank) + def getFactors(df: DataFrame): Set[(Int, Array[Float])] = { + df.select("id", "features").collect().map { case r => + (r.getInt(0), r.getAs[Array[Float]](1)) + }.toSet + } + assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) + assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + } +} + +object ALSSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allModelParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPredictionCol" + ) + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map( + "maxIter" -> 1, + "rank" -> 1, + "regParam" -> 0.01, + "numUserBlocks" -> 2, + "numItemBlocks" -> 2, + "implicitPrefs" -> true, + "alpha" -> 0.9, + "nonnegative" -> true, + "checkpointInterval" -> 20 + ) } From e391abdf2cb6098a35347bd123b815ee9ac5b689 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 18 Nov 2015 13:25:15 -0800 Subject: [PATCH 776/896] [SPARK-11813][MLLIB] Avoid serialization of vocab in Word2Vec jira: https://issues.apache.org/jira/browse/SPARK-11813 I found the problem while training on a large corpus. Avoiding serialization of the vocab in Word2Vec has 2 benefits. 1. Performance improvement from less serialization. 2. A large increase in the capacity of Word2Vec. Currently, in the fit of word2vec, the closure mainly includes the serialized Word2Vec instance and 2 global tables. The main part of Word2Vec is the vocab, of size vocab * 40 * 2 * 4 = 320 * vocab bytes; each of the 2 global tables is vocab * vectorSize * 8 bytes. If vectorSize = 20, that's 160 * vocab bytes per table. Their sum cannot exceed Int.MaxValue due to the restriction of ByteArrayOutputStream. In any case, avoiding serialization of the vocab helps decrease the size of the closure serialization, especially when vectorSize is small, thus allowing a larger vocabulary. Actually there's another possible fix: make local copies of the fields to avoid including Word2Vec in the closure. Let me know if that's preferred. Author: Yuhao Yang Closes #9803 from hhbyyh/w2vVocab.
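To make that sizing argument concrete, here is a rough back-of-the-envelope sketch; the 320-bytes-per-VocabWord and 8-bytes-per-weight figures are simply the ones quoted above, and the object and value names are illustrative, not part of the patch:

// Back-of-the-envelope closure-size estimate for Word2Vec.fit, using the figures
// quoted in the commit message above (illustrative sketch only).
object Word2VecClosureSizeSketch {
  def main(args: Array[String]): Unit = {
    val vectorSize = 20
    val bytesPerVocabWord = 40 * 2 * 4L       // ~320 bytes per VocabWord, per the estimate above
    val bytesPerTableRow = vectorSize * 8L    // ~160 bytes per word for each global table

    // Before this patch the vocab and both global tables were pulled into the closure;
    // afterwards vocab/vocabHash are @transient and only the two tables remain.
    val perWordBefore = bytesPerVocabWord + 2 * bytesPerTableRow  // 640 bytes per word
    val perWordAfter = 2 * bytesPerTableRow                       // 320 bytes per word

    val limit = Int.MaxValue.toLong           // ByteArrayOutputStream cap on the serialized closure
    println(s"max vocabulary before: ${limit / perWordBefore}")   // ~3.3 million words
    println(s"max vocabulary after:  ${limit / perWordAfter}")    // ~6.7 million words
  }
}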
--- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f3e4d346e358..7ab0d89d23a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -145,8 +145,8 @@ class Word2Vec extends Serializable with Logging { private var trainWordsCount = 0 private var vocabSize = 0 - private var vocab: Array[VocabWord] = null - private var vocabHash = mutable.HashMap.empty[String, Int] + @transient private var vocab: Array[VocabWord] = null + @transient private var vocabHash = mutable.HashMap.empty[String, Int] private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) From e222d758499ad2609046cc1a2cc8afb45c5bccbb Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 18 Nov 2015 13:30:29 -0800 Subject: [PATCH 777/896] [SPARK-11684][R][ML][DOC] Update SparkR glm API doc, user guide and example codes This PR includes: * Update SparkR:::glm, SparkR:::summary API docs. * Update SparkR machine learning user guide and example codes to show: * supporting feature interaction in R formula. * summary for gaussian GLM model. * coefficients for binomial GLM model. mengxr Author: Yanbo Liang Closes #9727 from yanboliang/spark-11684. --- R/pkg/R/mllib.R | 18 +++++-- docs/sparkr.md | 50 ++++++++++++++++--- .../ml/regression/LinearRegression.scala | 3 ++ 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index f23e1c7f1fce..8d3b4388ae57 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -32,6 +32,12 @@ setClass("PipelineModel", representation(model = "jobj")) #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. #' @param lambda Regularization parameter #' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @param standardize Whether to standardize features before training +#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and +#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory +#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an +#' analytical solution to the linear regression problem. The default value is "auto" +#' which means that the solver algorithm is selected automatically. #' @return a fitted MLlib model #' @rdname glm #' @export @@ -79,9 +85,15 @@ setMethod("predict", signature(object = "PipelineModel"), #' #' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param x A fitted MLlib model -#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See -#' summary.glm for more information. +#' @param object A fitted MLlib model +#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family +#' or a list with 'coefficients' component for binomial family. \cr +#' For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals +#' of the estimation, the 'coefficients' gives the estimated coefficients and their +#' estimated standard errors, t values and p-values. (It only available when model +#' fitted by normal solver.) \cr +#' For binomial family: the 'coefficients' gives the estimated coefficients. +#' See summary.glm for more information. 
\cr #' @rdname summary #' @export #' @examples diff --git a/docs/sparkr.md b/docs/sparkr.md index 437bd4756c27..a744b76be746 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -286,24 +286,37 @@ head(teenagers) # Machine Learning -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below shows the use of building a gaussian GLM model using SparkR. +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. + +The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). + +* For gaussian GLM model, it returns a list with 'devianceResiduals' and 'coefficients' components. The 'devianceResiduals' gives the min/max deviance residuals of the estimation; the 'coefficients' gives the estimated coefficients and their estimated standard errors, t values and p-values. (It only available when model fitted by normal solver.) +* For binomial GLM model, it returns a list with 'coefficients' component which gives the estimated coefficients. + +The examples below show the use of building gaussian GLM model and binomial GLM model using SparkR. + +## Gaussian GLM model
    {% highlight r %} # Create the DataFrame df <- createDataFrame(sqlContext, iris) -# Fit a linear model over the dataset. +# Fit a gaussian GLM model over the dataset. model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") -# Model coefficients are returned in a similar format to R's native glm(). +# Model summary are returned in a similar format to R's native glm(). summary(model) +##$devianceResiduals +## Min Max +## -1.307112 1.412532 +## ##$coefficients -## Estimate -##(Intercept) 2.2513930 -##Sepal_Width 0.8035609 -##Species_versicolor 1.4587432 -##Species_virginica 1.9468169 +## Estimate Std. Error t value Pr(>|t|) +##(Intercept) 2.251393 0.3697543 6.08889 9.568102e-09 +##Sepal_Width 0.8035609 0.106339 7.556598 4.187317e-12 +##Species_versicolor 1.458743 0.1121079 13.01195 0 +##Species_virginica 1.946817 0.100015 19.46525 0 # Make predictions based on the model. predictions <- predict(model, newData = df) @@ -317,3 +330,24 @@ head(select(predictions, "Sepal_Length", "prediction")) ##6 5.4 5.385281 {% endhighlight %}
    + +## Binomial GLM model + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, iris) +training <- filter(df, df$Species != "setosa") + +# Fit a binomial GLM model over the dataset. +model <- glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial") + +# Model coefficients are returned in a similar format to R's native glm(). +summary(model) +##$coefficients +## Estimate +##(Intercept) -13.046005 +##Sepal_Length 1.902373 +##Sepal_Width 0.404655 +{% endhighlight %} +
    diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ca55d5915e68..f7c44f0a51b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -145,6 +145,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Set the solver algorithm used for optimization. * In case of linear regression, this can be "l-bfgs", "normal" and "auto". + * "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton + * optimization method. "normal" denotes using Normal Equation as an analytical + * solution to the linear regression problem. * The default value is "auto" which means that the solver algorithm is * selected automatically. * @group setParam From 603a721c21488e17c15c45ce1de893e6b3d02274 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 18 Nov 2015 13:32:06 -0800 Subject: [PATCH 778/896] [SPARK-11820][ML][PYSPARK] PySpark LiR & LoR should support weightCol [SPARK-7685](https://issues.apache.org/jira/browse/SPARK-7685) and [SPARK-9642](https://issues.apache.org/jira/browse/SPARK-9642) have already supported setting weight column for ```LogisticRegression``` and ```LinearRegression```. It's a very important feature, PySpark should also support. mengxr Author: Yanbo Liang Closes #9811 from yanboliang/spark-11820. --- python/pyspark/ml/classification.py | 17 +++++++++-------- python/pyspark/ml/regression.py | 16 ++++++++-------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 603f2c7f798d..4a2982e2047f 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -36,7 +36,8 @@ @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, - HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, + HasWeightCol): """ Logistic regression. Currently, this class only supports binary classification. @@ -44,9 +45,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.mllib.linalg import Vectors >>> df = sc.parallelize([ - ... Row(label=1.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), + ... 
Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() + >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") >>> model = lr.fit(df) >>> model.weights DenseVector([5.5...]) @@ -80,12 +81,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() @@ -105,12 +106,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", - rawPredictionCol="rawPrediction", standardization=True): + rawPredictionCol="rawPrediction", standardization=True, weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ threshold=0.5, thresholds=None, probabilityCol="probability", \ - rawPredictionCol="rawPrediction", standardization=True) + rawPredictionCol="rawPrediction", standardization=True, weightCol=None) Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 7648bf13266b..944e648ec880 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -35,7 +35,7 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver): + HasStandardization, HasSolver, HasWeightCol): """ Linear regression. @@ -50,9 +50,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> from pyspark.mllib.linalg import Vectors >>> df = sqlContext.createDataFrame([ - ... (1.0, Vectors.dense(1.0)), - ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal") + ... (1.0, 2.0, Vectors.dense(1.0)), + ... 
(0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) + >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight") >>> model = lr.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001 @@ -75,11 +75,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -92,11 +92,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - standardization=True, solver="auto"): + standardization=True, solver="auto", weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - standardization=True, solver="auto") + standardization=True, solver="auto", weightCol=None) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs From 54db79702513e11335c33bcf3a03c59e965e6f16 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 18 Nov 2015 14:05:18 -0800 Subject: [PATCH 779/896] [SPARK-11544][SQL] sqlContext doesn't use PathFilter Apply the user supplied pathfilter while retrieving the files from fs. Author: Dilip Biswal Closes #9652 from dilipbiswal/spark-11544. 
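For context on how a user supplies that path filter, the mechanism is the standard Hadoop "mapreduce.input.pathFilter.class" setting, which the new test below also exercises; a minimal sketch, where the filter class, partition layout and input path are illustrative assumptions:

import org.apache.hadoop.fs.{Path, PathFilter}

// Skip any partition directory named p=2 (illustrative filter, mirroring the test below).
class SkipP2Filter extends PathFilter {
  override def accept(path: Path): Boolean = path.getParent.getName != "p=2"
}

// Register the filter on the SparkContext's Hadoop configuration, then read as usual.
// With this patch, the file listing applies the filter, so data under .../p=2 is skipped.
sqlContext.sparkContext.hadoopConfiguration.setClass(
  "mapreduce.input.pathFilter.class", classOf[SkipP2Filter], classOf[PathFilter])
val df = sqlContext.read.json("/tmp/events")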
--- .../apache/spark/sql/sources/interfaces.scala | 25 ++++++++++--- .../datasources/json/JsonSuite.scala | 36 +++++++++++++++++-- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b3d3bdf50df6..f9465157c936 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{JobConf, FileInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -447,9 +448,15 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - logInfo(s"Listing $qualified on driver") - Try(fs.listStatus(qualified)).getOrElse(Array.empty) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + } }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -847,8 +854,16 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6042b1178aff..f09b61e83815 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,19 +19,27 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} +import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.spark.rdd.RDD +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} import org.scalactic.Tolerance._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, 
LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1390,4 +1398,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("SPARK-11544 test pathfilter") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(2) + df.write.json(path + "/p=1") + df.write.json(path + "/p=2") + assert(sqlContext.read.json(path).count() === 4) + + val clonedConf = new Configuration(hadoopConfiguration) + try { + hadoopConfiguration.setClass( + "mapreduce.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + assert(sqlContext.read.json(path).count() === 2) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } } From 5df08949f5d9e5b4b0e9c2db50c1b4eb93383de3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 15:42:07 -0800 Subject: [PATCH 780/896] [SPARK-11810][SQL] Java-based encoder for opaque types in Datasets. This patch refactors the existing Kryo encoder expressions and adds support for Java serialization. Author: Reynold Xin Closes #9802 from rxin/SPARK-11810. --- .../scala/org/apache/spark/sql/Encoder.scala | 41 +++++++++--- .../sql/catalyst/expressions/objects.scala | 67 ++++++++++++------- .../catalyst/encoders/FlatEncoderSuite.scala | 27 ++++++-- .../org/apache/spark/sql/DatasetSuite.scala | 36 +++++++++- 4 files changed, 130 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 79c2255641c0..1ed5111440c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.reflect.{ClassTag, classTag} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo} +import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer} import org.apache.spark.sql.types._ /** @@ -43,28 +43,49 @@ trait Encoder[T] extends Serializable { */ object Encoders { - /** - * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. - * This encoder maps T into a single byte array (binary) field. - */ - def kryo[T: ClassTag]: Encoder[T] = { - val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true)) - val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T]) + /** A way to construct encoders using generic serializers. 
*/ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { ExpressionEncoder[T]( schema = new StructType().add("value", BinaryType), flat = true, - toRowExpressions = Seq(ser), - fromRowExpression = deser, + toRowExpressions = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + fromRowExpression = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), clsTag = classTag[T] ) } + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + */ + def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) + /** * Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java + * serialization. This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + */ + def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) + + /** + * Creates an encoder that serializes objects of type T using generic Java serialization. + * This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + */ + def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 489c6126f8cd..acf0da240051 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -21,7 +21,7 @@ import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark.SparkConf -import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer} +import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} @@ -517,29 +517,39 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy } } -/** Serializes an input object using Kryo serializer. */ -case class SerializeWithKryo(child: Expression) extends UnaryExpression { +/** + * Serializes an input object using a generic serializer (Kryo or Java). + * @param kryo if true, use Kryo. Otherwise, use Java. 
+ */ +case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends UnaryExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val input = child.gen(ctx) - val kryo = ctx.freshName("kryoSerializer") - val kryoClass = classOf[KryoSerializer].getName - val kryoInstanceClass = classOf[KryoSerializerInstance].getName - val sparkConfClass = classOf[SparkConf].getName + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" ctx.addMutableState( - kryoInstanceClass, - kryo, - s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + // Code to serialize. + val input = child.gen(ctx) s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $kryo.serialize(${input.value}, null).array(); + ${ev.value} = $serializer.serialize(${input.value}, null).array(); } """ } @@ -548,29 +558,38 @@ case class SerializeWithKryo(child: Expression) extends UnaryExpression { } /** - * Deserializes an input object using Kryo serializer. Note that the ClassTag is not an implicit - * parameter because TreeNode cannot copy implicit parameters. + * Serializes an input object using a generic serializer (Kryo or Java). Note that the ClassTag + * is not an implicit parameter because TreeNode cannot copy implicit parameters. + * @param kryo if true, use Kryo. Otherwise, use Java. */ -case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression { +case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) + extends UnaryExpression { override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val input = child.gen(ctx) - val kryo = ctx.freshName("kryoSerializer") - val kryoClass = classOf[KryoSerializer].getName - val kryoInstanceClass = classOf[KryoSerializerInstance].getName - val sparkConfClass = classOf[SparkConf].getName + // Code to initialize the serializer. + val serializer = ctx.freshName("serializer") + val (serializerClass, serializerInstanceClass) = { + if (kryo) { + (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName) + } else { + (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) + } + } + val sparkConf = s"new ${classOf[SparkConf].getName}()" ctx.addMutableState( - kryoInstanceClass, - kryo, - s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();") + serializerInstanceClass, + serializer, + s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + // Code to serialize. 
+ val input = child.gen(ctx) s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = (${ctx.javaType(dataType)}) - $kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); + $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala index 2729db84897a..6e0322fb6e01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -76,17 +76,34 @@ class FlatEncoderSuite extends ExpressionEncoderSuite { // Kryo encoders encodeDecodeTest( "hello", - Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]], + encoderFor(Encoders.kryo[String]), "kryo string") encodeDecodeTest( - new NotJavaSerializable(15), - Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]], + new KryoSerializable(15), + encoderFor(Encoders.kryo[KryoSerializable]), "kryo object serialization") + + // Java encoders + encodeDecodeTest( + "hello", + encoderFor(Encoders.javaSerialization[String]), + "java string") + encodeDecodeTest( + new JavaSerializable(15), + encoderFor(Encoders.javaSerialization[JavaSerializable]), + "java object serialization") } +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} -class NotJavaSerializable(val value: Int) { +/** For testing Java serialization based encoder. 
*/ +class JavaSerializable(val value: Int) extends Serializable { override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[NotJavaSerializable].value + this.value == other.asInstanceOf[JavaSerializable].value } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index b6db583dfe01..89d964aa3e46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -357,7 +357,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.toString == "[_1: int, _2: int]") } - test("kryo encoder") { + test("Kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() @@ -365,7 +365,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq((KryoData(1), 1L), (KryoData(2), 1L))) } - test("kryo encoder self join") { + test("Kryo encoder self join") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == @@ -375,6 +375,25 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (KryoData(2), KryoData(1)), (KryoData(2), KryoData(2)))) } + + test("Java encoder") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + + assert(ds.groupBy(p => p).count().collect().toSeq == + Seq((JavaData(1), 1L), (JavaData(2), 1L))) + } + + ignore("Java encoder self join") { + implicit val kryoEncoder = Encoders.javaSerialization[JavaData] + val ds = Seq(JavaData(1), JavaData(2)).toDS() + assert(ds.joinWith(ds, lit(true)).collect().toSet == + Set( + (JavaData(1), JavaData(1)), + (JavaData(1), JavaData(2)), + (JavaData(2), JavaData(1)), + (JavaData(2), JavaData(2)))) + } } @@ -406,3 +425,16 @@ class KryoData(val a: Int) { object KryoData { def apply(a: Int): KryoData = new KryoData(a) } + +/** Used to test Java encoder. */ +class JavaData(val a: Int) extends Serializable { + override def equals(other: Any): Boolean = { + a == other.asInstanceOf[JavaData].a + } + override def hashCode: Int = a + override def toString: String = s"JavaData($a)" +} + +object JavaData { + def apply(a: Int): JavaData = new JavaData(a) +} From 7e987de1770f4ab3d54bc05db8de0a1ef035941d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 18 Nov 2015 15:47:49 -0800 Subject: [PATCH 781/896] [SPARK-6787][ML] add read/write to estimators under ml.feature (1) Add read/write support to the following estimators under spark.ml: * CountVectorizer * IDF * MinMaxScaler * StandardScaler (a little awkward because we store some params in spark.mllib model) * StringIndexer Added some necessary method for read/write. Maybe we should add `private[ml] trait DefaultParamsReadable` and `DefaultParamsWritable` to save some boilerplate code, though we still need to override `load` for Java compatibility. jkbradley Author: Xiangrui Meng Closes #9798 from mengxr/SPARK-6787. 
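Taken together, these changes let each of the estimators and models listed above round-trip through the new `Writer`/`Readable` persistence API. A minimal usage sketch for one of them follows; it is illustrative only: the DataFrame, column names, and path are made up, an existing `sqlContext` is assumed, and it relies on the `write.save(path)` entry point provided by the `Writer` base class alongside the companion-object `load(path)` overridden in the diff below.

```scala
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

// Hypothetical input: a column of token sequences.
val df = sqlContext.createDataFrame(Seq(
  (0, Seq("a", "b", "c")),
  (1, Seq("a", "b", "b", "c", "a"))
)).toDF("id", "words")

val model: CountVectorizerModel = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .setVocabSize(3)
  .fit(df)

// Persist: params go into JSON metadata, the vocabulary into Parquet under <path>/data.
model.write.save("/tmp/count-vectorizer-model")

// Reload later; params and vocabulary are restored by the new Reader.
val restored = CountVectorizerModel.load("/tmp/count-vectorizer-model")
assert(restored.vocabulary.sameElements(model.vocabulary))
```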
--- .../spark/ml/feature/CountVectorizer.scala | 72 +++++++++++++++-- .../org/apache/spark/ml/feature/IDF.scala | 71 ++++++++++++++++- .../spark/ml/feature/MinMaxScaler.scala | 72 +++++++++++++++-- .../spark/ml/feature/StandardScaler.scala | 78 ++++++++++++++++++- .../spark/ml/feature/StringIndexer.scala | 70 +++++++++++++++-- .../ml/feature/CountVectorizerSuite.scala | 24 +++++- .../apache/spark/ml/feature/IDFSuite.scala | 19 ++++- .../spark/ml/feature/MinMaxScalerSuite.scala | 25 +++++- .../ml/feature/StandardScalerSuite.scala | 64 +++++++++++---- .../spark/ml/feature/StringIndexerSuite.scala | 19 ++++- 10 files changed, 467 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 49028e4b8506..5ff9bfb7d111 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -16,17 +16,19 @@ */ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.DataFrame import org.apache.spark.util.collection.OpenHashMap /** @@ -105,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable { def this() = this(Identifiable.randomUID("cntVec")) @@ -169,6 +171,19 @@ class CountVectorizer(override val uid: String) } override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object CountVectorizer extends Readable[CountVectorizer] { + + @Since("1.6.0") + override def read: Reader[CountVectorizer] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): CountVectorizer = super.load(path) } /** @@ -178,7 +193,9 @@ class CountVectorizer(override val uid: String) */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams { + extends Model[CountVectorizerModel] with CountVectorizerParams with Writable { + + import CountVectorizerModel._ def this(vocabulary: Array[String]) = { this(Identifiable.randomUID("cntVecModel"), vocabulary) @@ -232,4 +249,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent) copyValues(copied, extra) } + + @Since("1.6.0") + override def write: Writer = new CountVectorizerModelWriter(this) +} + +@Since("1.6.0") +object CountVectorizerModel extends Readable[CountVectorizerModel] { 
+ + private[CountVectorizerModel] + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer { + + private case class Data(vocabulary: Seq[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.vocabulary) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class CountVectorizerModelReader extends Reader[CountVectorizerModel] { + + private val className = "org.apache.spark.ml.feature.CountVectorizerModel" + + override def load(path: String): CountVectorizerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabulary") + .head() + val vocabulary = data.getAs[Seq[String]](0).toArray + val model = new CountVectorizerModel(metadata.uid, vocabulary) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader + + @Since("1.6.0") + override def load(path: String): CountVectorizerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 4c36df75d8aa..53ad34ef1264 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -60,7 +62,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of documents. 
*/ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable { def this() = this(Identifiable.randomUID("idf")) @@ -85,6 +87,19 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } override def copy(extra: ParamMap): IDF = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object IDF extends Readable[IDF] { + + @Since("1.6.0") + override def read: Reader[IDF] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): IDF = super.load(path) } /** @@ -95,7 +110,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase { + extends Model[IDFModel] with IDFBase with Writable { + + import IDFModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -117,4 +134,50 @@ class IDFModel private[ml] ( val copied = new IDFModel(uid, idfModel) copyValues(copied, extra).setParent(parent) } + + /** Returns the IDF vector. */ + @Since("1.6.0") + def idf: Vector = idfModel.idf + + @Since("1.6.0") + override def write: Writer = new IDFModelWriter(this) +} + +@Since("1.6.0") +object IDFModel extends Readable[IDFModel] { + + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer { + + private case class Data(idf: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.idf) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IDFModelReader extends Reader[IDFModel] { + + private val className = "org.apache.spark.ml.feature.IDFModel" + + override def load(path: String): IDFModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("idf") + .head() + val idf = data.getAs[Vector](0) + val model = new IDFModel(metadata.uid, new feature.IDFModel(idf)) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[IDFModel] = new IDFModelReader + + @Since("1.6.0") + override def load(path: String): IDFModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 1b494ec8b172..24d964fae834 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -17,11 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} -import org.apache.spark.ml.util.Identifiable + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util._ import 
org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.sql._ @@ -85,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -115,6 +118,19 @@ class MinMaxScaler(override val uid: String) } override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object MinMaxScaler extends Readable[MinMaxScaler] { + + @Since("1.6.0") + override def read: Reader[MinMaxScaler] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): MinMaxScaler = super.load(path) } /** @@ -131,7 +147,9 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable { + + import MinMaxScalerModel._ /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -175,4 +193,46 @@ class MinMaxScalerModel private[ml] ( val copied = new MinMaxScalerModel(uid, originalMin, originalMax) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new MinMaxScalerModelWriter(this) +} + +@Since("1.6.0") +object MinMaxScalerModel extends Readable[MinMaxScalerModel] { + + private[MinMaxScalerModel] + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer { + + private case class Data(originalMin: Vector, originalMax: Vector) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = new Data(instance.originalMin, instance.originalMax) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] { + + private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" + + override def load(path: String): MinMaxScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath) + .select("originalMin", "originalMax") + .head() + val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader + + @Since("1.6.0") + override def load(path: String): MinMaxScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index f6d0b0c0e9e7..ab04e5418dd4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} 
import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -57,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams { + with StandardScalerParams with Writable { def this() = this(Identifiable.randomUID("stdScal")) @@ -94,6 +96,19 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StandardScaler extends Readable[StandardScaler] { + + @Since("1.6.0") + override def read: Reader[StandardScaler] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): StandardScaler = super.load(path) } /** @@ -104,7 +119,9 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams { + extends Model[StandardScalerModel] with StandardScalerParams with Writable { + + import StandardScalerModel._ /** Standard deviation of the StandardScalerModel */ val std: Vector = scaler.std @@ -112,6 +129,14 @@ class StandardScalerModel private[ml] ( /** Mean of the StandardScalerModel */ val mean: Vector = scaler.mean + /** Whether to scale to unit standard deviation. */ + @Since("1.6.0") + def getWithStd: Boolean = scaler.withStd + + /** Whether to center data with mean. 
*/ + @Since("1.6.0") + def getWithMean: Boolean = scaler.withMean + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -138,4 +163,49 @@ class StandardScalerModel private[ml] ( val copied = new StandardScalerModel(uid, scaler) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new StandardScalerModelWriter(this) +} + +@Since("1.6.0") +object StandardScalerModel extends Readable[StandardScalerModel] { + + private[StandardScalerModel] + class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer { + + private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StandardScalerModelReader extends Reader[StandardScalerModel] { + + private val className = "org.apache.spark.ml.feature.StandardScalerModel" + + override def load(path: String): StandardScalerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) = + sqlContext.read.parquet(dataPath) + .select("std", "mean", "withStd", "withMean") + .head() + // This is very likely to change in the future because withStd and withMean should be params. + val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean) + val model = new StandardScalerModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[StandardScalerModel] = new StandardScalerModelReader + + @Since("1.6.0") + override def load(path: String): StandardScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f782a272d11d..f16f6afc002d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,13 +17,14 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ @@ -64,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase { + with StringIndexerBase with Writable { def this() = this(Identifiable.randomUID("strIdx")) @@ -92,6 +93,19 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) + + 
@Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StringIndexer extends Readable[StringIndexer] { + + @Since("1.6.0") + override def read: Reader[StringIndexer] = new DefaultParamsReader + + @Since("1.6.0") + override def load(path: String): StringIndexer = super.load(path) } /** @@ -107,7 +121,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod @Experimental class StringIndexerModel ( override val uid: String, - val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { + val labels: Array[String]) + extends Model[StringIndexerModel] with StringIndexerBase with Writable { + + import StringIndexerModel._ def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) @@ -176,6 +193,49 @@ class StringIndexerModel ( val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: StringIndexModelWriter = new StringIndexModelWriter(this) +} + +@Since("1.6.0") +object StringIndexerModel extends Readable[StringIndexerModel] { + + private[StringIndexerModel] + class StringIndexModelWriter(instance: StringIndexerModel) extends Writer { + + private case class Data(labels: Array[String]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.labels) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class StringIndexerModelReader extends Reader[StringIndexerModel] { + + private val className = "org.apache.spark.ml.feature.StringIndexerModel" + + override def load(path: String): StringIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("labels") + .head() + val labels = data.getAs[Seq[String]](0).toArray + val model = new StringIndexerModel(metadata.uid, labels) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: Reader[StringIndexerModel] = new StringIndexerModelReader + + @Since("1.6.0") + override def load(path: String): StringIndexerModel = super.load(path) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index e192fa4850af..9c9999017317 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -18,14 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { test("params") { + ParamsSuite.checkParams(new CountVectorizer) ParamsSuite.checkParams(new CountVectorizerModel(Array("empty"))) } @@ -164,4 +167,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext 
{ assert(features ~== expected absTol 1e-14) } } + + test("CountVectorizer read/write") { + val t = new CountVectorizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDF(0.5) + .setMinTF(3.0) + .setVocabSize(10) + testDefaultReadWrite(t) + } + + test("CountVectorizerModel read/write") { + val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinTF(3.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.vocabulary === instance.vocabulary) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 08f80af03429..bc958c15857b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { @@ -98,4 +99,20 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("IDF read/write") { + val t = new IDF() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinDocFreq(5) + testDefaultReadWrite(t) + } + + test("IDFModel read/write") { + val instance = new IDFModel("myIDFModel", new OldIDFModel(Vectors.dense(1.0, 2.0))) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.idf === instance.idf) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index c04dda41eea3..09183fe65b72 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} -class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("MinMaxScaler fit basic case") { val sqlContext = new SQLContext(sc) @@ -69,4 +69,25 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("MinMaxScaler read/write") { + val t = new MinMaxScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMax(1.0) + .setMin(-1.0) + testDefaultReadWrite(t) + } + + 
test("MinMaxScalerModel read/write") { + val instance = new MinMaxScalerModel( + "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0)) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMin(-1.0) + .setMax(1.0) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.originalMin === instance.originalMin) + assert(newInstance.originalMax === instance.originalMax) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 879a3ae87500..49a4b2efe0c2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -19,12 +19,16 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} -class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { @transient var data: Array[Vector] = _ @transient var resWithStd: Array[Vector] = _ @@ -56,23 +60,29 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ ) } - def assertResult(dataframe: DataFrame): Unit = { - dataframe.select("standarded_features", "expected").collect().foreach { + def assertResult(df: DataFrame): Unit = { + df.select("standardized_features", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => assert(vector1 ~== vector2 absTol 1E-5, "The vector value is not correct after standardization.") } } + test("params") { + ParamsSuite.checkParams(new StandardScaler) + val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0)) + ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel)) + } + test("Standardization with default parameter") { val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") - val standardscaler0 = new StandardScaler() + val standardScaler0 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .fit(df0) - assertResult(standardscaler0.transform(df0)) + assertResult(standardScaler0.transform(df0)) } test("Standardization with setter") { @@ -80,29 +90,49 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{ val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") - val standardscaler1 = new StandardScaler() + val standardScaler1 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(true) .setWithStd(true) .fit(df1) - val standardscaler2 = new StandardScaler() + val standardScaler2 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(true) 
.setWithStd(false) .fit(df2) - val standardscaler3 = new StandardScaler() + val standardScaler3 = new StandardScaler() .setInputCol("features") - .setOutputCol("standarded_features") + .setOutputCol("standardized_features") .setWithMean(false) .setWithStd(false) .fit(df3) - assertResult(standardscaler1.transform(df1)) - assertResult(standardscaler2.transform(df2)) - assertResult(standardscaler3.transform(df3)) + assertResult(standardScaler1.transform(df1)) + assertResult(standardScaler2.transform(df2)) + assertResult(standardScaler3.transform(df3)) + } + + test("StandardScaler read/write") { + val t = new StandardScaler() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setWithStd(false) + .setWithMean(true) + testDefaultReadWrite(t) + } + + test("StandardScalerModel read/write") { + val oldModel = new feature.StandardScalerModel( + Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true) + val instance = new StandardScalerModel("myStandardScalerModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.std === instance.std) + assert(newInstance.mean === instance.mean) + assert(newInstance.getWithStd === instance.getWithStd) + assert(newInstance.getWithMean === instance.getWithMean) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index be37bfb43883..749bfac74782 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -118,6 +118,23 @@ class StringIndexerSuite assert(indexerModel.transform(df).eq(df)) } + test("StringIndexer read/write") { + val t = new StringIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + testDefaultReadWrite(t) + } + + test("StringIndexerModel read/write") { + val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c")) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setHandleInvalid("skip") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.labels === instance.labels) + } + test("IndexToString params") { val idxToStr = new IndexToString() ParamsSuite.checkParams(idxToStr) @@ -175,7 +192,7 @@ class StringIndexerSuite assert(outSchema("output").dataType === StringType) } - test("read/write") { + test("IndexToString read/write") { val t = new IndexToString() .setInputCol("myInputCol") .setOutputCol("myOutputCol") From 3a9851936ddfe5bcb6a7f364d535fac977551f5d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 18 Nov 2015 15:55:41 -0800 Subject: [PATCH 782/896] [SPARK-11649] Properly set Akka frame size in SparkListenerSuite test SparkListenerSuite's _"onTaskGettingResult() called when result fetched remotely"_ test was extremely slow (1 to 4 minutes to run) and recently became extremely flaky, frequently failing with OutOfMemoryError. The root cause was the fact that this was using `System.setProperty` to set the Akka frame size, which was not actually modifying the frame size. As a result, this test would allocate much more data than necessary. The fix here is to simply use SparkConf in order to configure the frame size. Author: Josh Rosen Closes #9822 from JoshRosen/SPARK-11649. 
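The essence of the fix is ordering: Akka reads `spark.akka.frameSize` from the `SparkConf` when the `SparkContext`'s actor system is created, so mutating the system property afterwards changes nothing. A condensed sketch of what the corrected test does (note that `sc.env` is `private[spark]`, so the inspection below only compiles inside Spark's own test sources):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Configure the frame size (in MB) *before* the context, and hence the actor system, exists.
val conf = new SparkConf().set("spark.akka.frameSize", "1")
val sc = new SparkContext("local", "SparkListenerSuite", conf)

// The actor system was built from `conf`, so the setting is really in effect.
val akkaFrameSize = sc.env.actorSystem.settings.config
  .getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
assert(akkaFrameSize == 1024 * 1024)

// By contrast, this only mutates a JVM property; the running actor system never re-reads it.
System.setProperty("spark.akka.frameSize", "100")
```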
--- .../org/apache/spark/scheduler/SparkListenerSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 53102b9f1c93..84e545851f49 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -269,14 +269,15 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("onTaskGettingResult() called when result fetched remotely") { - sc = new SparkContext("local", "SparkListenerSuite") + val conf = new SparkConf().set("spark.akka.frameSize", "1") + sc = new SparkContext("local", "SparkListenerSuite", conf) val listener = new SaveTaskEvents sc.addSparkListener(listener) // Make a task whose result is larger than the akka frame size - System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + assert(akkaFrameSize === 1024 * 1024) val result = sc.parallelize(Seq(1), 1) .map { x => 1.to(akkaFrameSize).toArray } .reduce { case (x, y) => x } From c07a50b86254578625be777b1890ff95e832ac6e Mon Sep 17 00:00:00 2001 From: Derek Dagit Date: Wed, 18 Nov 2015 15:56:54 -0800 Subject: [PATCH 783/896] [SPARK-10930] History "Stages" page "duration" can be confusing Author: Derek Dagit Closes #9051 from d2r/spark-10930-ui-max-task-dur. --- .../org/apache/spark/ui/jobs/StageTable.scala | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index ea806d09b600..2a1c3c1a50ec 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -145,9 +145,22 @@ private[ui] class StageTableBase( case None => "Unknown" } val finishTime = s.completionTime.getOrElse(System.currentTimeMillis) - val duration = s.submissionTime.map { t => - if (finishTime > t) finishTime - t else System.currentTimeMillis - t - } + + // The submission time for a stage is misleading because it counts the time + // the stage waits to be launched. (SPARK-10930) + val taskLaunchTimes = + stageData.taskData.values.map(_.taskInfo.launchTime).filter(_ > 0) + val duration: Option[Long] = + if (taskLaunchTimes.nonEmpty) { + val startTime = taskLaunchTimes.min + if (finishTime > startTime) { + Some(finishTime - startTime) + } else { + Some(System.currentTimeMillis() - startTime) + } + } else { + None + } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val inputRead = stageData.inputBytes From 4b117121900e5f242e7c8f46a69164385f0da7cc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 18 Nov 2015 16:00:35 -0800 Subject: [PATCH 784/896] [SPARK-11495] Fix potential socket / file handle leaks that were found via static analysis The HP Fortify Opens Source Review team (https://www.hpfod.com/open-source-review-project) reported a handful of potential resource leaks that were discovered using their static analysis tool. We should fix the issues identified by their scan. Author: Josh Rosen Closes #9455 from JoshRosen/fix-potential-resource-leaks. 
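All of the fixes below follow the same shape: open the resources, do the work, and close them in a `finally` block via Guava's `Closeables.close`, swallowing a close-time `IOException` only when an earlier exception is already propagating. A condensed rendering of that pattern, written in Scala for brevity (the Java diffs below are the authoritative changes; the file-writing helper is a made-up example):

```scala
import java.io.{DataOutputStream, FileOutputStream}

import com.google.common.io.Closeables

def writeBlocks(path: String, blocks: Array[Array[Byte]]): Unit = {
  var out: DataOutputStream = null
  // Suppress close-time IOExceptions only if the body itself failed,
  // so a failure in close() is not masked on the success path.
  var suppressCloseExceptions = true
  try {
    out = new DataOutputStream(new FileOutputStream(path))
    blocks.foreach(b => out.write(b))
    suppressCloseExceptions = false
  } finally {
    Closeables.close(out, suppressCloseExceptions)
  }
}
```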
--- .../spark/unsafe/map/BytesToBytesMap.java | 7 ++++ .../unsafe/sort/UnsafeSorterSpillReader.java | 38 +++++++++++-------- .../streaming/JavaCustomReceiver.java | 31 +++++++-------- .../network/ChunkFetchIntegrationSuite.java | 15 ++++++-- .../shuffle/TestShuffleDataContext.java | 32 ++++++++++------ .../spark/streaming/JavaReceiverAPISuite.java | 20 ++++++---- 6 files changed, 90 insertions(+), 53 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 04694dc54418..3387f9a4177c 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -24,6 +24,7 @@ import java.util.LinkedList; import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -272,6 +273,7 @@ private void advanceToNextPage() { } } try { + Closeables.close(reader, /* swallowIOException = */ false); reader = spillWriters.getFirst().getReader(blockManager); recordsInPage = -1; } catch (IOException e) { @@ -318,6 +320,11 @@ public Location next() { try { reader.loadNext(); } catch (IOException e) { + try { + reader.close(); + } catch(IOException e2) { + logger.error("Error while closing spill reader", e2); + } // Scala iterator does not handle exception Platform.throwException(e); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 039e940a357e..dcb13e6581e5 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -20,8 +20,7 @@ import java.io.*; import com.google.common.io.ByteStreams; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import com.google.common.io.Closeables; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; @@ -31,10 +30,8 @@ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description * of the file format). 
*/ -public final class UnsafeSorterSpillReader extends UnsafeSorterIterator { - private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class); +public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable { - private final File file; private InputStream in; private DataInputStream din; @@ -52,11 +49,15 @@ public UnsafeSorterSpillReader( File file, BlockId blockId) throws IOException { assert (file.length() > 0); - this.file = file; final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); - this.in = blockManager.wrapForCompression(blockId, bs); - this.din = new DataInputStream(this.in); - numRecordsRemaining = din.readInt(); + try { + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + numRecordsRemaining = din.readInt(); + } catch (IOException e) { + Closeables.close(bs, /* swallowIOException = */ true); + throw e; + } } @Override @@ -75,12 +76,7 @@ public void loadNext() throws IOException { ByteStreams.readFully(in, arr, 0, recordLength); numRecordsRemaining--; if (numRecordsRemaining == 0) { - in.close(); - if (!file.delete() && file.exists()) { - logger.warn("Unable to delete spill file {}", file.getPath()); - } - in = null; - din = null; + close(); } } @@ -103,4 +99,16 @@ public int getRecordLength() { public long getKeyPrefix() { return keyPrefix; } + + @Override + public void close() throws IOException { + if (in != null) { + try { + in.close(); + } finally { + in = null; + din = null; + } + } + } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 99df259b4e8e..4b50fbf59f80 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming; import com.google.common.collect.Lists; +import com.google.common.io.Closeables; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -121,23 +122,23 @@ public void onStop() { /** Create a socket connection and receive data until receiver is stopped */ private void receive() { - Socket socket = null; - String userInput = null; - try { - // connect to the server - socket = new Socket(host, port); - - BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); - - // Until stopped or connection broken continue reading - while (!isStopped() && (userInput = reader.readLine()) != null) { - System.out.println("Received data '" + userInput + "'"); - store(userInput); + Socket socket = null; + BufferedReader reader = null; + String userInput = null; + try { + // connect to the server + socket = new Socket(host, port); + reader = new BufferedReader(new InputStreamReader(socket.getInputStream())); + // Until stopped or connection broken continue reading + while (!isStopped() && (userInput = reader.readLine()) != null) { + System.out.println("Received data '" + userInput + "'"); + store(userInput); + } + } finally { + Closeables.close(reader, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - reader.close(); - socket.close(); - // Restart in an attempt to connect again when server is active again restart("Trying to connect again"); } catch(ConnectException ce) { diff --git 
a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index dc5fa1cee69b..50a324e29338 100644 --- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -31,6 +31,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.google.common.io.Closeables; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -78,10 +79,15 @@ public static void setUp() throws Exception { testFile = File.createTempFile("shuffle-test-file", "txt"); testFile.deleteOnExit(); RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); - byte[] fileContent = new byte[1024]; - new Random().nextBytes(fileContent); - fp.write(fileContent); - fp.close(); + boolean shouldSuppressIOException = true; + try { + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + shouldSuppressIOException = false; + } finally { + Closeables.close(fp, shouldSuppressIOException); + } final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); @@ -117,6 +123,7 @@ public StreamManager getStreamManager() { @AfterClass public static void tearDown() { + bufferChunk.release(); server.close(); clientFactory.close(); testFile.delete(); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 3fdde054ab6c..7ac1ca128aed 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.io.OutputStream; +import com.google.common.io.Closeables; import com.google.common.io.Files; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; @@ -60,21 +61,28 @@ public void cleanup() { public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; - OutputStream dataStream = new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); - DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + OutputStream dataStream = null; + DataOutputStream indexStream = null; + boolean suppressExceptionsDuringClose = true; - long offset = 0; - indexStream.writeLong(offset); - for (byte[] block : blocks) { - offset += block.length; - dataStream.write(block); + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + suppressExceptionsDuringClose = false; + } finally { + Closeables.close(dataStream, 
suppressExceptionsDuringClose); + Closeables.close(indexStream, suppressExceptionsDuringClose); } - - dataStream.close(); - indexStream.close(); } /** Creates reducer blocks in a hash-based data format within our local dirs. */ diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index ec2bffd6a5b9..7a8ef9d14784 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -23,6 +23,7 @@ import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.junit.Assert.*; +import com.google.common.io.Closeables; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -121,14 +122,19 @@ public void onStop() { private void receive() { try { - Socket socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + Socket socket = null; + BufferedReader in = null; + try { + socket = new Socket(host, port); + in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + } finally { + Closeables.close(in, /* swallowIOException = */ true); + Closeables.close(socket, /* swallowIOException = */ true); } - in.close(); - socket.close(); } catch(ConnectException ce) { ce.printStackTrace(); restart("Could not connect", ce); From a402c92c92b2e1c85d264f6077aec8f6d6a08270 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 18 Nov 2015 16:08:06 -0800 Subject: [PATCH 785/896] [SPARK-11814][STREAMING] Add better default checkpoint duration DStream checkpoint interval is by default set at max(10 second, batch interval). That's bad for large batch intervals where the checkpoint interval = batch interval, and RDDs get checkpointed every batch. This PR is to set the checkpoint interval of trackStateByKey to 10 * batch duration. Author: Tathagata Das Closes #9805 from tdas/SPARK-11814. 
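Concretely, after this change a `trackStateByKey` stream checkpoints every `10 * batchDuration` unless the user overrides it; for example, a 10-second batch interval now checkpoints every 100 seconds instead of every batch. A minimal usage sketch (the socket source, host/port, and tracking logic are placeholders; the two-argument tracking function mirrors the form used in the suite below, and an existing `SparkContext` `sc` is assumed):

```scala
import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

val ssc = new StreamingContext(sc, Seconds(10))   // assumes an existing SparkContext `sc`
ssc.checkpoint("/tmp/streaming-checkpoint")       // state tracking requires a checkpoint dir

// Hypothetical source: count occurrences per key coming in over a socket.
val pairs = ssc.socketTextStream("localhost", 9999).map(word => (word, 1))

// Tracking function: add new counts to the running state for each key.
val trackingFunc = (value: Option[Int], state: State[Int]) => {
  val sum = value.getOrElse(0) + (if (state.exists()) state.get() else 0)
  state.update(sum)
  sum
}

// With a 10s batch interval, the internal state stream now checkpoints every 100s by default...
val stateStream = pairs.trackStateByKey(StateSpec.function(trackingFunc))

// ...and an explicitly configured interval still takes precedence:
stateStream.checkpoint(Seconds(50))
```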
--- .../streaming/dstream/TrackStateDStream.scala | 13 ++++++ .../streaming/TrackStateByKeySuite.scala | 44 ++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 98e881e6ae11..0ada1111ce30 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -25,6 +25,7 @@ import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} +import org.apache.spark.streaming.dstream.InternalTrackStateDStream._ /** * :: Experimental :: @@ -120,6 +121,14 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT /** Enable automatic checkpointing */ override val mustCheckpoint = true + /** Override the default checkpoint duration */ + override def initialize(time: Time): Unit = { + if (checkpointDuration == null) { + checkpointDuration = slideDuration * DEFAULT_CHECKPOINT_DURATION_MULTIPLIER + } + super.initialize(time) + } + /** Method that generates a RDD for the given time */ override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD @@ -141,3 +150,7 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT } } } + +private[streaming] object InternalTrackStateDStream { + private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10 +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index e3072b444284..58aef74c0040 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -22,9 +22,10 @@ import java.io.File import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag +import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} @@ -57,6 +58,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef sc = new SparkContext(conf) } + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + test("state - get, exists, update, remove, ") { var state: StateImpl[Int] = null @@ -436,6 +443,41 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) } + test("trackStateByKey - checkpoint durations") { + val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream) + + def testCheckpointDuration( + batchDuration: Duration, + expectedCheckpointDuration: Duration, + explicitCheckpointDuration: Option[Duration] = None + ): Unit = { + try { + ssc = new StreamingContext(sc, batchDuration) + val inputStream = new TestInputStream(ssc, 
Seq.empty[Seq[Int]], 2).map(_ -> 1) + val dummyFunc = (value: Option[Int], state: State[Int]) => 0 + val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) + val internalTrackStateStream = trackStateStream invokePrivate privateMethod() + + explicitCheckpointDuration.foreach { d => + trackStateStream.checkpoint(d) + } + trackStateStream.register() + ssc.start() // should initialize all the checkpoint durations + assert(trackStateStream.checkpointDuration === null) + assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) + } finally { + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + } + } + + testCheckpointDuration(Milliseconds(100), Seconds(1)) + testCheckpointDuration(Seconds(1), Seconds(10)) + testCheckpointDuration(Seconds(10), Seconds(100)) + + testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2))) + testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) + } private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], From 921900fd06362474f8caac675803d526a0986d70 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 18 Nov 2015 16:19:00 -0800 Subject: [PATCH 786/896] [SPARK-11791] Fix flaky test in BatchedWriteAheadLogSuite stack trace of failure: ``` org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 62 times over 1.006322071 seconds. Last failure message: Argument(s) are different! Wanted: writeAheadLog.write( java.nio.HeapByteBuffer[pos=0 lim=124 cap=124], 10 ); -> at org.apache.spark.streaming.util.BatchedWriteAheadLogSuite$$anonfun$23$$anonfun$apply$mcV$sp$15.apply(WriteAheadLogSuite.scala:518) Actual invocation has different arguments: writeAheadLog.write( java.nio.HeapByteBuffer[pos=0 lim=124 cap=124], 10 ); -> at org.apache.spark.streaming.util.WriteAheadLogSuite$BlockingWriteAheadLog.write(WriteAheadLogSuite.scala:756) ``` I believe the issue was that due to a race condition, the ordering of the events could be messed up in the final ByteBuffer, therefore the comparison fails. By adding eventually between the requests, we make sure the ordering is preserved. Note that in real life situations, the ordering across threads will not matter. Another solution would be to implement a custom mockito matcher that sorts and then compares the results, but that kind of sounds like overkill to me. Let me know what you think tdas zsxwing Author: Burak Yavuz Closes #9790 from brkyvz/fix-flaky-2. 
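The fix reduces to a standard Mockito idiom: capture the argument that was actually written and compare it order-insensitively, instead of verifying against a buffer built with a hard-coded ordering. A stripped-down sketch of that idiom, with a hypothetical `Sink` trait standing in for the wrapped write-ahead log:

```scala
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{eq => meq}
import org.mockito.Mockito._
import org.scalatest.FunSuite

// Hypothetical collaborator whose batched argument has a nondeterministic internal order.
trait Sink {
  def write(batch: Seq[String], time: Long): Unit
}

class OrderInsensitiveVerifySuite extends FunSuite {
  test("verify a batched call without depending on element order") {
    val sink = mock(classOf[Sink])

    // In a real test the code under test would make this call; the batch order
    // depends on how concurrent producers were interleaved.
    sink.write(Seq("event3", "event2", "event4"), 10L)

    // Capture what was actually passed instead of matching a pre-built value.
    val captor = ArgumentCaptor.forClass(classOf[Seq[String]])
    verify(sink, times(1)).write(captor.capture(), meq(10L))
    assert(captor.getValue.toSet === Set("event2", "event3", "event4"))
  }
}
```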
--- .../spark/streaming/util/WriteAheadLogSuite.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 7f80d6ecdbbb..eaa88ea3cd38 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -30,6 +30,7 @@ import scala.language.{implicitConversions, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.mockito.ArgumentCaptor import org.mockito.Matchers.{eq => meq} import org.mockito.Matchers._ import org.mockito.Mockito._ @@ -507,15 +508,18 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } blockingWal.allowWrite() - val buffer1 = wrapArrayArrayByte(Array(event1)) - val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) + val buffer = wrapArrayArrayByte(Array(event1)) + val queuedEvents = Set(event2, event3, event4, event5) eventually(timeout(1 second)) { assert(batchedWal.invokePrivate(queueLength()) === 0) - verify(wal, times(1)).write(meq(buffer1), meq(3L)) + verify(wal, times(1)).write(meq(buffer), meq(3L)) // the file name should be the timestamp of the last record, as events should be naturally // in order of timestamp, and we need the last element. - verify(wal, times(1)).write(meq(buffer2), meq(10L)) + val bufferCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + verify(wal, times(1)).write(bufferCaptor.capture(), meq(10L)) + val records = BatchedWriteAheadLog.deaggregate(bufferCaptor.getValue).map(byteBufferToString) + assert(records.toSet === queuedEvents) } } From 59a501359a267fbdb7689058693aa788703e54b1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 18 Nov 2015 16:48:09 -0800 Subject: [PATCH 787/896] [SPARK-11636][SQL] Support classes defined in the REPL with Encoders Before this PR there were two things that would blow up if you called `df.as[MyClass]` if `MyClass` was defined in the REPL: - [x] Because `classForName` doesn't work on the munged names returned by `tpe.erasure.typeSymbol.asClass.fullName` - [x] Because we don't have anything to pass into the constructor for the `$outer` pointer. Note that this PR is just adding the infrastructure for working with inner classes in encoder and is not yet sufficient to make them work in the REPL. Currently, the implementation show in https://github.com/marmbrus/spark/commit/95cec7d413b930b36420724fafd829bef8c732ab is causing a bug that breaks code gen due to some interaction between janino and the `ExecutorClassLoader`. This will be addressed in a follow-up PR. Author: Michael Armbrust Closes #9602 from marmbrus/dataset-replClasses. 
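The user-visible piece of this infrastructure is the new `OuterScopes` registry: code that defines a case class inside another class (a REPL cell, a test suite) registers the enclosing instance so encoder resolution can supply the `$outer` pointer when rebuilding objects. A hedged sketch of the intended usage; `ReplCell` and `Record` are made-up names, and as the description above notes, this patch alone is not yet enough to make the full REPL scenario work end to end.

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.encoders.OuterScopes

// Stand-in for a REPL cell: `Record` is a member class of `ReplCell`, so constructing
// it during deserialization needs a reference to an enclosing `ReplCell` instance.
class ReplCell(sqlContext: SQLContext) {
  import sqlContext.implicits._

  case class Record(i: Int)

  def run(): Unit = {
    // Register the enclosing instance; resolve() looks it up by the declaring class
    // name and passes it to NewInstance as the outer pointer.
    OuterScopes.addOuterScope(this)
    val ds = Seq(Record(1), Record(2)).toDS()
    ds.collect().foreach(println)
  }
}
```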
--- .../spark/sql/catalyst/ScalaReflection.scala | 81 ++++++++++--------- .../catalyst/encoders/ExpressionEncoder.scala | 26 +++++- .../sql/catalyst/encoders/OuterScopes.scala | 42 ++++++++++ .../catalyst/encoders/ProductEncoder.scala | 6 +- .../expressions/codegen/CodegenFallback.scala | 2 +- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateProjection.scala | 10 +-- .../codegen/GenerateSafeProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../codegen/GenerateUnsafeRowJoiner.scala | 6 +- .../sql/catalyst/expressions/literals.scala | 6 ++ .../sql/catalyst/expressions/objects.scala | 42 ++++++++-- .../encoders/ExpressionEncoderSuite.scala | 7 +- .../encoders/ProductEncoderSuite.scala | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../org/apache/spark/sql/GroupedDataset.scala | 8 +- .../aggregate/TypedAggregateExpression.scala | 19 ++--- 17 files changed, 193 insertions(+), 82 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 38828e59a215..59ccf356f2c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -35,17 +35,6 @@ object ScalaReflection extends ScalaReflection { // class loader of the current thread. override def mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) -} - -/** - * Support for generating catalyst schemas for scala objects. - */ -trait ScalaReflection { - /** The universe we work in (runtime or macro) */ - val universe: scala.reflect.api.Universe - - /** The mirror used to access types in the universe */ - def mirror: universe.Mirror import universe._ @@ -53,30 +42,6 @@ trait ScalaReflection { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map - case class Schema(dataType: DataType, nullable: Boolean) - - /** Returns a Sequence of attributes for the given case class type. */ - def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case Schema(s: StructType, _) => - s.toAttributes - } - - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } - - /** - * Return the Scala Type for `T` in the current classloader mirror. - * - * Use this method instead of the convenience method `universe.typeOf`, which - * assumes that all types can be found in the classloader that loaded scala-reflect classes. - * That's not necessarily the case when running using Eclipse launchers or even - * Sbt console or test (without `fork := true`). - * - * @see SPARK-5281 - */ - def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe - /** * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping * to a native type, an ObjectType is returned. 
Special handling is also used for Arrays including @@ -114,7 +79,9 @@ trait ScalaReflection { } ObjectType(cls) - case other => ObjectType(Utils.classForName(className)) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) } } @@ -640,6 +607,48 @@ trait ScalaReflection { } } } +} + +/** + * Support for generating catalyst schemas for scala objects. Note that unlike its companion + * object, this trait able to work in both the runtime and the compile time (macro) universe. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + /** The mirror used to access types in the universe */ + def mirror: universe.Mirror + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + case class Schema(dataType: DataType, nullable: Boolean) + + /** Returns a Sequence of attributes for the given case class type. */ + def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { + case Schema(s: StructType, _) => + s.toAttributes + } + + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor[T: TypeTag]: Schema = + ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + + /** + * Return the Scala Type for `T` in the current classloader mirror. + * + * Use this method instead of the convenience method `universe.typeOf`, which + * assumes that all types can be found in the classloader that loaded scala-reflect classes. + * That's not necessarily the case when running using Eclipse launchers or even + * Sbt console or test (without `fork := true`). + * + * @see SPARK-5281 + */ + def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index b977f278c5b5..456b59500847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.encoders +import java.util.concurrent.ConcurrentMap + import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.util.Utils -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.{AnalysisException, Encoder} import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ @@ -211,7 +213,9 @@ case class ExpressionEncoder[T]( * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the * given schema. 
*/ - def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { + def resolve( + schema: Seq[Attribute], + outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { val positionToAttribute = AttributeMap.toIndex(schema) val unbound = fromRowExpression transform { case b: BoundReference => positionToAttribute(b.ordinal) @@ -219,7 +223,23 @@ case class ExpressionEncoder[T]( val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) - copy(fromRowExpression = analyzedPlan.expressions.head.children.head) + + // In order to construct instances of inner classes (for example those declared in a REPL cell), + // we need an instance of the outer scope. This rule substitues those outer objects into + // expressions that are missing them by looking up the name in the SQLContexts `outerScopes` + // registry. + copy(fromRowExpression = analyzedPlan.expressions.head.children.head transform { + case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass => + val outer = outerScopes.get(n.cls.getDeclaringClass.getName) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " + + s"to the scope that this class was defined in. " + "" + + "Try moving this class out of its parent class.") + } + + n.copy(outerPointer = Some(Literal.fromObject(outer))) + }) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala new file mode 100644 index 000000000000..a753b187bcd3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker + +object OuterScopes { + @transient + lazy val outerScopes: ConcurrentMap[String, AnyRef] = + new MapMaker().weakValues().makeMap() + + /** + * Adds a new outer scope to this context that can be used when instantiating an `inner class` + * during deserialialization. Inner classes are created when a case class is defined in the + * Spark REPL and registering the outer scope that this class was defined in allows us to create + * new instances on the spark executors. In normal use, users should not need to call this + * function. + * + * Warning: this function operates on the assumption that there is only ever one instance of any + * given wrapper class. 
+ */ + def addOuterScope(outer: AnyRef): Unit = { + outerScopes.putIfAbsent(outer.getClass.getName, outer) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index 55c4ee11b20f..2914c6ee790c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -31,6 +31,7 @@ import scala.reflect.ClassTag object ProductEncoder { import ScalaReflection.universe._ + import ScalaReflection.mirror import ScalaReflection.localTypeOf import ScalaReflection.dataTypeFor import ScalaReflection.Schema @@ -420,8 +421,7 @@ object ProductEncoder { } } - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString @@ -429,7 +429,7 @@ object ProductEncoder { val dataType = schemaFor(fieldType).dataType // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { + if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { constructorFor(fieldType, Some(addToPath(fieldName))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index d51a8dede7f3..a31574c251af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -34,7 +34,7 @@ trait CodegenFallback extends Expression { val objectTerm = ctx.freshName("obj") s""" /* expression: ${this} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 4b66069b5f55..40189f087776 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -82,7 +82,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificMutableProjection(expr); } @@ -109,7 +109,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu return (InternalRow) mutableRow; } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allProjections // copy all the results into MutableRow diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index c0d313b2e130..f229f2000d8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -167,7 +167,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { ${initMutableStates(ctx)} } - public Object apply(Object r) { + public java.lang.Object apply(java.lang.Object r) { // GenerateProjection does not work with UnsafeRows. assert(!(r instanceof ${classOf[UnsafeRow].getName})); return new SpecificRow((InternalRow) r); @@ -186,14 +186,14 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - public Object genericGet(int i) { + public java.lang.Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases } return null; } - public void update(int i, Object value) { + public void update(int i, java.lang.Object value) { if (value == null) { setNullAt(i); return; @@ -212,7 +212,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return result; } - public boolean equals(Object other) { + public boolean equals(java.lang.Object other) { if (other instanceof SpecificRow) { SpecificRow row = (SpecificRow) other; $columnChecks @@ -222,7 +222,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; + java.lang.Object[] arr = new java.lang.Object[${expressions.length}]; ${copyColumns} return new ${classOf[GenericInternalRow].getName}(arr); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f0ed8645d923..b7926bda3de1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -148,7 +148,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) val code = s""" - public Object generate($exprType[] expr) { + public java.lang.Object generate($exprType[] expr) { return new SpecificSafeProjection(expr); } @@ -165,7 +165,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${initMutableStates(ctx)} } - public Object apply(Object _i) { + public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allExpressions return mutableRow; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4c17d02a2372..7b6c9373ebe3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -324,7 +324,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" - public Object generate($exprType[] exprs) { + public java.lang.Object generate($exprType[] exprs) { return new SpecificUnsafeProjection(exprs); } @@ -342,7 +342,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } // Scala.Function1 need this - public Object apply(Object row) { + public java.lang.Object apply(java.lang.Object row) { return apply((InternalRow) row); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index da91ff29537b..da602d9b4bce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -159,7 +159,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U // ------------------------ Finally, put everything together --------------------------- // val code = s""" - |public Object generate($exprType[] exprs) { + |public java.lang.Object generate($exprType[] exprs) { | return new SpecificUnsafeRowJoiner(); |} | @@ -176,9 +176,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | buf = new byte[sizeInBytes]; | } | - | final Object obj1 = row1.getBaseObject(); + | final java.lang.Object obj1 = row1.getBaseObject(); | final long offset1 = row1.getBaseOffset(); - | final Object obj2 = row2.getBaseObject(); + | final java.lang.Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | | $copyBitset diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 455fa2427c26..e34fd49be838 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -48,6 +48,12 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } + /** + * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object + * into code generation. 
+ */ + def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index acf0da240051..f865a9408ef4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.encoders.ProductEncoder import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -178,6 +179,15 @@ case class Invoke( } } +object NewInstance { + def apply( + cls: Class[_], + arguments: Seq[Expression], + propagateNull: Boolean = false, + dataType: DataType): NewInstance = + new NewInstance(cls, arguments, propagateNull, dataType, None) +} + /** * Constructs a new instance of the given class, using the result of evaluating the specified * expressions as arguments. @@ -189,12 +199,15 @@ case class Invoke( * @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you * to manually specify the type when the object in question is a valid internal * representation (i.e. ArrayData) instead of an object. + * @param outerPointer If the object being constructed is an inner class the outerPointer must + * for the containing class must be specified. 
*/ case class NewInstance( cls: Class[_], arguments: Seq[Expression], - propagateNull: Boolean = true, - dataType: DataType) extends Expression { + propagateNull: Boolean, + dataType: DataType, + outerPointer: Option[Literal]) extends Expression { private val className = cls.getName override def nullable: Boolean = propagateNull @@ -209,30 +222,43 @@ case class NewInstance( val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") + val outer = outerPointer.map(_.gen(ctx)) + + val setup = + s""" + ${argGen.map(_.code).mkString("\n")} + ${outer.map(_.code.mkString("")).getOrElse("")} + """.stripMargin + + val constructorCall = outer.map { gen => + s"""${gen.value}.new ${cls.getSimpleName}($argString)""" + }.getOrElse { + s"new $className($argString)" + } + if (propagateNull) { val objNullCheck = if (ctx.defaultValue(dataType) == "null") { s"${ev.isNull} = ${ev.value} == null;" } else { "" } - val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" + s""" - ${argGen.map(_.code).mkString("\n")} + $setup boolean ${ev.isNull} = true; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if ($argsNonNull) { - ${ev.value} = new $className($argString); + ${ev.value} = $constructorCall; ${ev.isNull} = false; } """ } else { s""" - ${argGen.map(_.code).mkString("\n")} + $setup - $javaType ${ev.value} = new $className($argString); + $javaType ${ev.value} = $constructorCall; final boolean ${ev.isNull} = ${ev.value} == null; """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 9fe64b4cf10e..cde0364f3dd9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql.catalyst.encoders import java.util.Arrays +import java.util.concurrent.ConcurrentMap + +import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.AttributeReference @@ -25,6 +28,8 @@ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.ArrayType abstract class ExpressionEncoderSuite extends SparkFunSuite { + val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + protected def encodeDecodeTest[T]( input: T, encoder: ExpressionEncoder[T], @@ -32,7 +37,7 @@ abstract class ExpressionEncoderSuite extends SparkFunSuite { test(s"encode/decode for $testName: $input") { val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.resolve(schema).bind(schema) + val boundEncoder = encoder.resolve(schema, outers).bind(schema) val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index bc539d62c537..1798514c5c38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -53,6 +53,10 @@ case class RepeatedData( case class SpecificCollection(l: List[Int]) class ProductEncoderSuite extends ExpressionEncoderSuite { + 
outers.put(getClass.getName, this) + + case class InnerClass(i: Int) + productTest(InnerClass(1)) productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b644f6ad3096..bdcdc5d47cba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -74,7 +74,7 @@ class Dataset[T] private[sql]( /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(queryExecution.analyzed.output) + unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes) private implicit def classTag = resolvedTEncoder.clsTag @@ -375,7 +375,7 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - resolvedTEncoder, + resolvedTEncoder.bind(queryExecution.analyzed.output), queryExecution.analyzed.output).named :: Nil, logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 3f84e22a1025..7e5acbe8517d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -52,8 +52,10 @@ class GroupedDataset[K, T] private[sql]( private implicit val unresolvedKEncoder = encoderFor(kEncoder) private implicit val unresolvedTEncoder = encoderFor(tEncoder) - private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes) - private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes) + private val resolvedKEncoder = + unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) + private val resolvedTEncoder = + unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 3f2775896bb8..6ce41aaf01e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -52,8 +52,8 @@ object TypedAggregateExpression { */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], - aEncoder: Option[ExpressionEncoder[Any]], - bEncoder: ExpressionEncoder[Any], + aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. + bEncoder: ExpressionEncoder[Any], // Should be bound. 
cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -92,9 +92,6 @@ case class TypedAggregateExpression( // We let the dataset do the binding for us. lazy val boundA = aEncoder.get - val bAttributes = bEncoder.schema.toAttributes - lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) - private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { // todo: need a more neat way to assign the value. var i = 0 @@ -114,24 +111,24 @@ case class TypedAggregateExpression( override def update(buffer: MutableRow, input: InternalRow): Unit = { val inputA = boundA.fromRow(input) - val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val merged = aggregator.reduce(currentB, inputA) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer, returned) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1) - val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2) + val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) val merged = aggregator.merge(b1, b2) - val returned = boundB.toRow(merged) + val returned = bEncoder.toRow(merged) updateBuffer(buffer1, returned) } override def eval(buffer: InternalRow): Any = { - val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val result = cEncoder.toRow(aggregator.finish(b)) dataType match { case _: StructType => result From e99d3392068bc929c900a4cc7b50e9e2b437a23a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 18 Nov 2015 18:34:01 -0800 Subject: [PATCH 788/896] [SPARK-11839][ML] refactor save/write traits * add "ML" prefix to reader/writer/readable/writable to avoid name collision with java.util.* * define `DefaultParamsReadable/Writable` and use them to save some code * use `super.load` instead so people can jump directly to the doc of `Readable.load`, which documents the Java compatibility issues jkbradley Author: Xiangrui Meng Closes #9827 from mengxr/SPARK-11839. 
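After the rename, the user-facing surface reads roughly as below (a sketch only; the `training` DataFrame with "label"/"features" columns and the paths are placeholders): estimators mixing in `DefaultParamsWritable` get an `MLWriter` via `write`, and a companion's `load` is now just sugar for `read.load`.

```scala
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}

// Assumes an existing DataFrame `training` with "label" and "features" columns.
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)

// DefaultParamsWritable supplies `write` (an MLWriter) with no extra code.
lr.write.overwrite().save("/tmp/spark-ml/lr-estimator")
// The companion's `load` simply delegates to `read.load`.
val sameLr = LogisticRegression.load("/tmp/spark-ml/lr-estimator")

// Models keep their custom writers/readers behind MLWritable / MLReadable.
val model = lr.fit(training)
model.write.overwrite().save("/tmp/spark-ml/lr-model")
val sameModel = LogisticRegressionModel.load("/tmp/spark-ml/lr-model")
```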
--- .../scala/org/apache/spark/ml/Pipeline.scala | 40 +++++++++---------- .../classification/LogisticRegression.scala | 29 +++++++------- .../apache/spark/ml/feature/Binarizer.scala | 12 ++---- .../apache/spark/ml/feature/Bucketizer.scala | 12 ++---- .../spark/ml/feature/CountVectorizer.scala | 22 ++++------ .../org/apache/spark/ml/feature/DCT.scala | 12 ++---- .../apache/spark/ml/feature/HashingTF.scala | 12 ++---- .../org/apache/spark/ml/feature/IDF.scala | 23 +++++------ .../apache/spark/ml/feature/Interaction.scala | 12 ++---- .../spark/ml/feature/MinMaxScaler.scala | 22 ++++------ .../org/apache/spark/ml/feature/NGram.scala | 12 ++---- .../apache/spark/ml/feature/Normalizer.scala | 12 ++---- .../spark/ml/feature/OneHotEncoder.scala | 12 ++---- .../ml/feature/PolynomialExpansion.scala | 12 ++---- .../ml/feature/QuantileDiscretizer.scala | 12 ++---- .../spark/ml/feature/SQLTransformer.scala | 13 ++---- .../spark/ml/feature/StandardScaler.scala | 22 ++++------ .../spark/ml/feature/StopWordsRemover.scala | 12 ++---- .../spark/ml/feature/StringIndexer.scala | 32 +++++---------- .../apache/spark/ml/feature/Tokenizer.scala | 24 +++-------- .../spark/ml/feature/VectorAssembler.scala | 12 ++---- .../spark/ml/feature/VectorSlicer.scala | 12 ++---- .../apache/spark/ml/recommendation/ALS.scala | 27 +++++-------- .../ml/regression/LinearRegression.scala | 30 ++++++-------- .../org/apache/spark/ml/util/ReadWrite.scala | 40 ++++++++++++------- .../org/apache/spark/ml/PipelineSuite.scala | 14 +++---- .../spark/ml/util/DefaultReadWriteTest.scala | 17 ++++---- 27 files changed, 190 insertions(+), 321 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 25f0c696f42b..b0f22e042ec5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -29,8 +29,8 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, Logging} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.Reader -import org.apache.spark.ml.util.Writer +import org.apache.spark.ml.util.MLReader +import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -89,7 +89,7 @@ abstract class PipelineStage extends Params with Logging { * an identity transformer. 
*/ @Experimental -class Pipeline(override val uid: String) extends Estimator[PipelineModel] with Writable { +class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable { def this() = this(Identifiable.randomUID("pipeline")) @@ -174,16 +174,16 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with W theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } - override def write: Writer = new Pipeline.PipelineWriter(this) + override def write: MLWriter = new Pipeline.PipelineWriter(this) } -object Pipeline extends Readable[Pipeline] { +object Pipeline extends MLReadable[Pipeline] { - override def read: Reader[Pipeline] = new PipelineReader + override def read: MLReader[Pipeline] = new PipelineReader - override def load(path: String): Pipeline = read.load(path) + override def load(path: String): Pipeline = super.load(path) - private[ml] class PipelineWriter(instance: Pipeline) extends Writer { + private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter { SharedReadWrite.validateStages(instance.getStages) @@ -191,7 +191,7 @@ object Pipeline extends Readable[Pipeline] { SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) } - private[ml] class PipelineReader extends Reader[Pipeline] { + private[ml] class PipelineReader extends MLReader[Pipeline] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.Pipeline" @@ -202,7 +202,7 @@ object Pipeline extends Readable[Pipeline] { } } - /** Methods for [[Reader]] and [[Writer]] shared between [[Pipeline]] and [[PipelineModel]] */ + /** Methods for [[MLReader]] and [[MLWriter]] shared between [[Pipeline]] and [[PipelineModel]] */ private[ml] object SharedReadWrite { import org.json4s.JsonDSL._ @@ -210,7 +210,7 @@ object Pipeline extends Readable[Pipeline] { /** Check that all stages are Writable */ def validateStages(stages: Array[PipelineStage]): Unit = { stages.foreach { - case stage: Writable => // good + case stage: MLWritable => // good case other => throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" + s" because it contains a stage which does not implement Writable. Non-Writable stage:" + @@ -245,7 +245,7 @@ object Pipeline extends Readable[Pipeline] { // Save stages val stagesDir = new Path(path, "stages").toString - stages.zipWithIndex.foreach { case (stage: Writable, idx: Int) => + stages.zipWithIndex.foreach { case (stage: MLWritable, idx: Int) => stage.write.save(getStagePath(stage.uid, idx, stages.length, stagesDir)) } } @@ -285,7 +285,7 @@ object Pipeline extends Readable[Pipeline] { val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[Reader[PipelineStage]].load(stagePath) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) } (metadata.uid, stages) } @@ -308,7 +308,7 @@ object Pipeline extends Readable[Pipeline] { class PipelineModel private[ml] ( override val uid: String, val stages: Array[Transformer]) - extends Model[PipelineModel] with Writable with Logging { + extends Model[PipelineModel] with MLWritable with Logging { /** A Java/Python-friendly auxiliary constructor. 
*/ private[ml] def this(uid: String, stages: ju.List[Transformer]) = { @@ -333,18 +333,18 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } - override def write: Writer = new PipelineModel.PipelineModelWriter(this) + override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) } -object PipelineModel extends Readable[PipelineModel] { +object PipelineModel extends MLReadable[PipelineModel] { import Pipeline.SharedReadWrite - override def read: Reader[PipelineModel] = new PipelineModelReader + override def read: MLReader[PipelineModel] = new PipelineModelReader - override def load(path: String): PipelineModel = read.load(path) + override def load(path: String): PipelineModel = super.load(path) - private[ml] class PipelineModelWriter(instance: PipelineModel) extends Writer { + private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) @@ -352,7 +352,7 @@ object PipelineModel extends Readable[PipelineModel] { instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } - private[ml] class PipelineModelReader extends Reader[PipelineModel] { + private[ml] class PipelineModelReader extends MLReader[PipelineModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.PipelineModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 71c2533bcbf4..a3cc49f7f018 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,9 +29,9 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas @Experimental class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] - with LogisticRegressionParams with Writable with Logging { + with LogisticRegressionParams with DefaultParamsWritable with Logging { def this() = this(Identifiable.randomUID("logreg")) @@ -385,12 +385,11 @@ class LogisticRegression(override val uid: String) } override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) - - override def write: Writer = new DefaultParamsWriter(this) } -object LogisticRegression extends Readable[LogisticRegression] { - override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression] +object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { + + override def load(path: String): LogisticRegression = super.load(path) } /** @@ -403,7 +402,7 @@ class LogisticRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams with Writable { + 
with LogisticRegressionParams with MLWritable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients @@ -519,26 +518,26 @@ class LogisticRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * * This also does not save the [[parent]] currently. */ - override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } -object LogisticRegressionModel extends Readable[LogisticRegressionModel] { +object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { - override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader + override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader - override def load(path: String): LogisticRegressionModel = read.load(path) + override def load(path: String): LogisticRegressionModel = super.load(path) - /** [[Writer]] instance for [[LogisticRegressionModel]] */ + /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data( numClasses: Int, @@ -558,7 +557,7 @@ object LogisticRegressionModel extends Readable[LogisticRegressionModel] { } private[classification] class LogisticRegressionModelReader - extends Reader[LogisticRegressionModel] { + extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index e2be6547d8f0..63c06581482e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with Writable with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("binarizer")) @@ -86,17 +86,11 @@ final class Binarizer(override val uid: String) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Binarizer extends Readable[Binarizer] { - - @Since("1.6.0") - override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] +object Binarizer extends DefaultParamsReadable[Binarizer] { @Since("1.6.0") - override def load(path: String): Binarizer = read.load(path) + override def load(path: String): Binarizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 7095fbd70aa0..324353a96afb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -36,7 +36,7 @@ import 
org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ @Experimental final class Bucketizer(override val uid: String) - extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable { + extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("bucketizer")) @@ -93,12 +93,9 @@ final class Bucketizer(override val uid: String) override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } -object Bucketizer extends Readable[Bucketizer] { +object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** We require splits to be of length >= 3 and to be in strictly increasing order. */ private[feature] def checkSplits(splits: Array[Double]): Boolean = { @@ -140,8 +137,5 @@ object Bucketizer extends Readable[Bucketizer] { } @Since("1.6.0") - override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer] - - @Since("1.6.0") - override def load(path: String): Bucketizer = read.load(path) + override def load(path: String): Bucketizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 5ff9bfb7d111..4969cf42450d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -107,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit */ @Experimental class CountVectorizer(override val uid: String) - extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Estimator[CountVectorizerModel] with CountVectorizerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("cntVec")) @@ -171,16 +171,10 @@ class CountVectorizer(override val uid: String) } override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object CountVectorizer extends Readable[CountVectorizer] { - - @Since("1.6.0") - override def read: Reader[CountVectorizer] = new DefaultParamsReader +object CountVectorizer extends DefaultParamsReadable[CountVectorizer] { @Since("1.6.0") override def load(path: String): CountVectorizer = super.load(path) @@ -193,7 +187,7 @@ object CountVectorizer extends Readable[CountVectorizer] { */ @Experimental class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) - extends Model[CountVectorizerModel] with CountVectorizerParams with Writable { + extends Model[CountVectorizerModel] with CountVectorizerParams with MLWritable { import CountVectorizerModel._ @@ -251,14 +245,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin } @Since("1.6.0") - override def write: Writer = new CountVectorizerModelWriter(this) + override def write: MLWriter = new CountVectorizerModelWriter(this) } @Since("1.6.0") -object CountVectorizerModel extends Readable[CountVectorizerModel] { +object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private[CountVectorizerModel] - class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer { + class CountVectorizerModelWriter(instance: CountVectorizerModel) extends MLWriter { private case class Data(vocabulary: Seq[String]) @@ -270,7 
+264,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } } - private class CountVectorizerModelReader extends Reader[CountVectorizerModel] { + private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { private val className = "org.apache.spark.ml.feature.CountVectorizerModel" @@ -288,7 +282,7 @@ object CountVectorizerModel extends Readable[CountVectorizerModel] { } @Since("1.6.0") - override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader + override def read: MLReader[CountVectorizerModel] = new CountVectorizerModelReader @Since("1.6.0") override def load(path: String): CountVectorizerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 6ea5a616173e..6bed72164a1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class DCT(override val uid: String) - extends UnaryTransformer[Vector, Vector, DCT] with Writable { + extends UnaryTransformer[Vector, Vector, DCT] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("dct")) @@ -69,17 +69,11 @@ class DCT(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object DCT extends Readable[DCT] { - - @Since("1.6.0") - override def read: Reader[DCT] = new DefaultParamsReader[DCT] +object DCT extends DefaultParamsReadable[DCT] { @Since("1.6.0") - override def load(path: String): DCT = read.load(path) + override def load(path: String): DCT = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 6d2ea675f561..9e15835429a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{ArrayType, StructType} */ @Experimental class HashingTF(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("hashingTF")) @@ -77,17 +77,11 @@ class HashingTF(override val uid: String) } override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object HashingTF extends Readable[HashingTF] { - - @Since("1.6.0") - override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF] +object HashingTF extends DefaultParamsReadable[HashingTF] { @Since("1.6.0") - override def load(path: String): HashingTF = read.load(path) + override def load(path: String): HashingTF = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 53ad34ef1264..0e00ef6f2ee2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -62,7 +62,8 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Compute the Inverse Document Frequency (IDF) given a collection of 
documents. */ @Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable { +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idf")) @@ -87,16 +88,10 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa } override def copy(extra: ParamMap): IDF = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IDF extends Readable[IDF] { - - @Since("1.6.0") - override def read: Reader[IDF] = new DefaultParamsReader +object IDF extends DefaultParamsReadable[IDF] { @Since("1.6.0") override def load(path: String): IDF = super.load(path) @@ -110,7 +105,7 @@ object IDF extends Readable[IDF] { class IDFModel private[ml] ( override val uid: String, idfModel: feature.IDFModel) - extends Model[IDFModel] with IDFBase with Writable { + extends Model[IDFModel] with IDFBase with MLWritable { import IDFModel._ @@ -140,13 +135,13 @@ class IDFModel private[ml] ( def idf: Vector = idfModel.idf @Since("1.6.0") - override def write: Writer = new IDFModelWriter(this) + override def write: MLWriter = new IDFModelWriter(this) } @Since("1.6.0") -object IDFModel extends Readable[IDFModel] { +object IDFModel extends MLReadable[IDFModel] { - private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer { + private[IDFModel] class IDFModelWriter(instance: IDFModel) extends MLWriter { private case class Data(idf: Vector) @@ -158,7 +153,7 @@ object IDFModel extends Readable[IDFModel] { } } - private class IDFModelReader extends Reader[IDFModel] { + private class IDFModelReader extends MLReader[IDFModel] { private val className = "org.apache.spark.ml.feature.IDFModel" @@ -176,7 +171,7 @@ object IDFModel extends Readable[IDFModel] { } @Since("1.6.0") - override def read: Reader[IDFModel] = new IDFModelReader + override def read: MLReader[IDFModel] = new IDFModelReader @Since("1.6.0") override def load(path: String): IDFModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 9df6b311cc9d..2181119f04a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.types._ @Since("1.6.0") @Experimental class Interaction @Since("1.6.0") (override val uid: String) extends Transformer - with HasInputCols with HasOutputCol with Writable { + with HasInputCols with HasOutputCol with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("interaction")) @@ -224,19 +224,13 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer require($(inputCols).length > 0, "Input cols must have non-zero length.") require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Interaction extends Readable[Interaction] { - - @Since("1.6.0") - override def read: Reader[Interaction] = new DefaultParamsReader[Interaction] +object Interaction extends DefaultParamsReadable[Interaction] { @Since("1.6.0") - override def load(path: String): Interaction = read.load(path) + override def load(path: String): Interaction = super.load(path) } /** diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 24d964fae834..ed24eabb5044 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -88,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H */ @Experimental class MinMaxScaler(override val uid: String) - extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("minMaxScal")) @@ -118,16 +118,10 @@ class MinMaxScaler(override val uid: String) } override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object MinMaxScaler extends Readable[MinMaxScaler] { - - @Since("1.6.0") - override def read: Reader[MinMaxScaler] = new DefaultParamsReader +object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] { @Since("1.6.0") override def load(path: String): MinMaxScaler = super.load(path) @@ -147,7 +141,7 @@ class MinMaxScalerModel private[ml] ( override val uid: String, val originalMin: Vector, val originalMax: Vector) - extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable { + extends Model[MinMaxScalerModel] with MinMaxScalerParams with MLWritable { import MinMaxScalerModel._ @@ -195,14 +189,14 @@ class MinMaxScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new MinMaxScalerModelWriter(this) + override def write: MLWriter = new MinMaxScalerModelWriter(this) } @Since("1.6.0") -object MinMaxScalerModel extends Readable[MinMaxScalerModel] { +object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private[MinMaxScalerModel] - class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer { + class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends MLWriter { private case class Data(originalMin: Vector, originalMax: Vector) @@ -214,7 +208,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } } - private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] { + private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" @@ -231,7 +225,7 @@ object MinMaxScalerModel extends Readable[MinMaxScalerModel] { } @Since("1.6.0") - override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader + override def read: MLReader[MinMaxScalerModel] = new MinMaxScalerModelReader @Since("1.6.0") override def load(path: String): MinMaxScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 4a17acd95199..65414ecbefbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class NGram(override val uid: String) - extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable { + extends UnaryTransformer[Seq[String], Seq[String], NGram] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("ngram")) @@ -66,17 +66,11 @@ class NGram(override val uid: String) } 
override protected def outputDataType: DataType = new ArrayType(StringType, false) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object NGram extends Readable[NGram] { - - @Since("1.6.0") - override def read: Reader[NGram] = new DefaultParamsReader[NGram] +object NGram extends DefaultParamsReadable[NGram] { @Since("1.6.0") - override def load(path: String): NGram = read.load(path) + override def load(path: String): NGram = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 9df6a091d505..c2d514fd9629 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class Normalizer(override val uid: String) - extends UnaryTransformer[Vector, Vector, Normalizer] with Writable { + extends UnaryTransformer[Vector, Vector, Normalizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("normalizer")) @@ -56,17 +56,11 @@ class Normalizer(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT() - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Normalizer extends Readable[Normalizer] { - - @Since("1.6.0") - override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer] +object Normalizer extends DefaultParamsReadable[Normalizer] { @Since("1.6.0") - override def load(path: String): Normalizer = read.load(path) + override def load(path: String): Normalizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 4e2adfaafa21..d70164eaf022 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental class OneHotEncoder(override val uid: String) extends Transformer - with HasInputCol with HasOutputCol with Writable { + with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("oneHot")) @@ -165,17 +165,11 @@ class OneHotEncoder(override val uid: String) extends Transformer } override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object OneHotEncoder extends Readable[OneHotEncoder] { - - @Since("1.6.0") - override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder] +object OneHotEncoder extends DefaultParamsReadable[OneHotEncoder] { @Since("1.6.0") - override def load(path: String): OneHotEncoder = read.load(path) + override def load(path: String): OneHotEncoder = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 49415398325f..08610593fadd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class PolynomialExpansion(override val 
uid: String) - extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with Writable { + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("poly")) @@ -63,9 +63,6 @@ class PolynomialExpansion(override val uid: String) override protected def outputDataType: DataType = new VectorUDT() override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } /** @@ -81,7 +78,7 @@ class PolynomialExpansion(override val uid: String) * current index and increment it properly for sparse input. */ @Since("1.6.0") -object PolynomialExpansion extends Readable[PolynomialExpansion] { +object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product @@ -182,8 +179,5 @@ object PolynomialExpansion extends Readable[PolynomialExpansion] { } @Since("1.6.0") - override def read: Reader[PolynomialExpansion] = new DefaultParamsReader[PolynomialExpansion] - - @Since("1.6.0") - override def load(path: String): PolynomialExpansion = read.load(path) + override def load(path: String): PolynomialExpansion = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 2da5c966d296..7bf67c6325a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -60,7 +60,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w */ @Experimental final class QuantileDiscretizer(override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase with Writable { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("quantileDiscretizer")) @@ -93,13 +93,10 @@ final class QuantileDiscretizer(override val uid: String) } override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { +object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { /** * Sampling from the given dataset to collect quantile statistics. 
*/ @@ -179,8 +176,5 @@ object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { } @Since("1.6.0") - override def read: Reader[QuantileDiscretizer] = new DefaultParamsReader[QuantileDiscretizer] - - @Since("1.6.0") - override def load(path: String): QuantileDiscretizer = read.load(path) + override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index c115064ff301..3a735017ba83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.StructType */ @Experimental @Since("1.6.0") -class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer with Writable { +class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer + with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("sql")) @@ -77,17 +78,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor @Since("1.6.0") override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object SQLTransformer extends Readable[SQLTransformer] { - - @Since("1.6.0") - override def read: Reader[SQLTransformer] = new DefaultParamsReader[SQLTransformer] +object SQLTransformer extends DefaultParamsReadable[SQLTransformer] { @Since("1.6.0") - override def load(path: String): SQLTransformer = read.load(path) + override def load(path: String): SQLTransformer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index ab04e5418dd4..1f689c1da1ba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -59,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with */ @Experimental class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams with Writable { + with StandardScalerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stdScal")) @@ -96,16 +96,10 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StandardScaler extends Readable[StandardScaler] { - - @Since("1.6.0") - override def read: Reader[StandardScaler] = new DefaultParamsReader +object StandardScaler extends DefaultParamsReadable[StandardScaler] { @Since("1.6.0") override def load(path: String): StandardScaler = super.load(path) @@ -119,7 +113,7 @@ object StandardScaler extends Readable[StandardScaler] { class StandardScalerModel private[ml] ( override val uid: String, scaler: feature.StandardScalerModel) - extends Model[StandardScalerModel] with StandardScalerParams with Writable { + extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { import StandardScalerModel._ @@ -165,14 +159,14 @@ class StandardScalerModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new 
StandardScalerModelWriter(this) + override def write: MLWriter = new StandardScalerModelWriter(this) } @Since("1.6.0") -object StandardScalerModel extends Readable[StandardScalerModel] { +object StandardScalerModel extends MLReadable[StandardScalerModel] { private[StandardScalerModel] - class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer { + class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) @@ -184,7 +178,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } } - private class StandardScalerModelReader extends Reader[StandardScalerModel] { + private class StandardScalerModelReader extends MLReader[StandardScalerModel] { private val className = "org.apache.spark.ml.feature.StandardScalerModel" @@ -204,7 +198,7 @@ object StandardScalerModel extends Readable[StandardScalerModel] { } @Since("1.6.0") - override def read: Reader[StandardScalerModel] = new StandardScalerModelReader + override def read: MLReader[StandardScalerModel] = new StandardScalerModelReader @Since("1.6.0") override def load(path: String): StandardScalerModel = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index f1146988dcc7..318808596dc6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -86,7 +86,7 @@ private[spark] object StopWords { */ @Experimental class StopWordsRemover(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("stopWords")) @@ -154,17 +154,11 @@ class StopWordsRemover(override val uid: String) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StopWordsRemover extends Readable[StopWordsRemover] { - - @Since("1.6.0") - override def read: Reader[StopWordsRemover] = new DefaultParamsReader[StopWordsRemover] +object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] { @Since("1.6.0") - override def load(path: String): StopWordsRemover = read.load(path) + override def load(path: String): StopWordsRemover = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f16f6afc002d..97a2e4f6d6ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -65,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha */ @Experimental class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] - with StringIndexerBase with Writable { + with StringIndexerBase with DefaultParamsWritable { def this() = this(Identifiable.randomUID("strIdx")) @@ -93,16 +93,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod } override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object StringIndexer 
extends Readable[StringIndexer] { - - @Since("1.6.0") - override def read: Reader[StringIndexer] = new DefaultParamsReader +object StringIndexer extends DefaultParamsReadable[StringIndexer] { @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -122,7 +116,7 @@ object StringIndexer extends Readable[StringIndexer] { class StringIndexerModel ( override val uid: String, val labels: Array[String]) - extends Model[StringIndexerModel] with StringIndexerBase with Writable { + extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { import StringIndexerModel._ @@ -199,10 +193,10 @@ class StringIndexerModel ( } @Since("1.6.0") -object StringIndexerModel extends Readable[StringIndexerModel] { +object StringIndexerModel extends MLReadable[StringIndexerModel] { private[StringIndexerModel] - class StringIndexModelWriter(instance: StringIndexerModel) extends Writer { + class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter { private case class Data(labels: Array[String]) @@ -214,7 +208,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } } - private class StringIndexerModelReader extends Reader[StringIndexerModel] { + private class StringIndexerModelReader extends MLReader[StringIndexerModel] { private val className = "org.apache.spark.ml.feature.StringIndexerModel" @@ -232,7 +226,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { } @Since("1.6.0") - override def read: Reader[StringIndexerModel] = new StringIndexerModelReader + override def read: MLReader[StringIndexerModel] = new StringIndexerModelReader @Since("1.6.0") override def load(path: String): StringIndexerModel = super.load(path) @@ -249,7 +243,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] { */ @Experimental class IndexToString private[ml] (override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("idxToStr")) @@ -316,17 +310,11 @@ class IndexToString private[ml] (override val uid: String) override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object IndexToString extends Readable[IndexToString] { - - @Since("1.6.0") - override def read: Reader[IndexToString] = new DefaultParamsReader[IndexToString] +object IndexToString extends DefaultParamsReadable[IndexToString] { @Since("1.6.0") - override def load(path: String): IndexToString = read.load(path) + override def load(path: String): IndexToString = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0e4445d1e2fa..8ad7bbedaab5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class Tokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], Tokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("tok")) @@ -46,19 +46,13 @@ class Tokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) 
override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object Tokenizer extends Readable[Tokenizer] { - - @Since("1.6.0") - override def read: Reader[Tokenizer] = new DefaultParamsReader[Tokenizer] +object Tokenizer extends DefaultParamsReadable[Tokenizer] { @Since("1.6.0") - override def load(path: String): Tokenizer = read.load(path) + override def load(path: String): Tokenizer = super.load(path) } /** @@ -70,7 +64,7 @@ object Tokenizer extends Readable[Tokenizer] { */ @Experimental class RegexTokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], RegexTokenizer] with Writable { + extends UnaryTransformer[String, Seq[String], RegexTokenizer] with DefaultParamsWritable { def this() = this(Identifiable.randomUID("regexTok")) @@ -145,17 +139,11 @@ class RegexTokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object RegexTokenizer extends Readable[RegexTokenizer] { - - @Since("1.6.0") - override def read: Reader[RegexTokenizer] = new DefaultParamsReader[RegexTokenizer] +object RegexTokenizer extends DefaultParamsReadable[RegexTokenizer] { @Since("1.6.0") - override def load(path: String): RegexTokenizer = read.load(path) + override def load(path: String): RegexTokenizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7e54205292ca..0feec0549852 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ */ @Experimental class VectorAssembler(override val uid: String) - extends Transformer with HasInputCols with HasOutputCol with Writable { + extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecAssembler")) @@ -120,19 +120,13 @@ class VectorAssembler(override val uid: String) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorAssembler extends Readable[VectorAssembler] { - - @Since("1.6.0") - override def read: Reader[VectorAssembler] = new DefaultParamsReader[VectorAssembler] +object VectorAssembler extends DefaultParamsReadable[VectorAssembler] { @Since("1.6.0") - override def load(path: String): VectorAssembler = read.load(path) + override def load(path: String): VectorAssembler = super.load(path) private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 911582b55b57..5410a50bc2e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType */ @Experimental final class VectorSlicer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with Writable { + 
extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vectorSlicer")) @@ -151,13 +151,10 @@ final class VectorSlicer(override val uid: String) } override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object VectorSlicer extends Readable[VectorSlicer] { +object VectorSlicer extends DefaultParamsReadable[VectorSlicer] { /** Return true if given feature indices are valid */ private[feature] def validIndices(indices: Array[Int]): Boolean = { @@ -174,8 +171,5 @@ object VectorSlicer extends Readable[VectorSlicer] { } @Since("1.6.0") - override def read: Reader[VectorSlicer] = new DefaultParamsReader[VectorSlicer] - - @Since("1.6.0") - override def load(path: String): VectorSlicer = read.load(path) + override def load(path: String): VectorSlicer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index d92514d2e239..795b73c4c212 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -185,7 +185,7 @@ class ALSModel private[ml] ( val rank: Int, @transient val userFactors: DataFrame, @transient val itemFactors: DataFrame) - extends Model[ALSModel] with ALSModelParams with Writable { + extends Model[ALSModel] with ALSModelParams with MLWritable { /** @group setParam */ def setUserCol(value: String): this.type = set(userCol, value) @@ -225,19 +225,19 @@ class ALSModel private[ml] ( } @Since("1.6.0") - override def write: Writer = new ALSModel.ALSModelWriter(this) + override def write: MLWriter = new ALSModel.ALSModelWriter(this) } @Since("1.6.0") -object ALSModel extends Readable[ALSModel] { +object ALSModel extends MLReadable[ALSModel] { @Since("1.6.0") - override def read: Reader[ALSModel] = new ALSModelReader + override def read: MLReader[ALSModel] = new ALSModelReader @Since("1.6.0") - override def load(path: String): ALSModel = read.load(path) + override def load(path: String): ALSModel = super.load(path) - private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer { + private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { val extraMetadata = render("rank" -> instance.rank) @@ -249,7 +249,7 @@ object ALSModel extends Readable[ALSModel] { } } - private[recommendation] class ALSModelReader extends Reader[ALSModel] { + private[recommendation] class ALSModelReader extends MLReader[ALSModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.recommendation.ALSModel" @@ -309,7 +309,8 @@ object ALSModel extends Readable[ALSModel] { * preferences rather than explicit ratings given to items. 
*/ @Experimental -class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable { +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams + with DefaultParamsWritable { import org.apache.spark.ml.recommendation.ALS.Rating @@ -391,9 +392,6 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w } override def copy(extra: ParamMap): ALS = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @@ -406,7 +404,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams w * than 2 billion. */ @DeveloperApi -object ALS extends Readable[ALS] with Logging { +object ALS extends DefaultParamsReadable[ALS] with Logging { /** * :: DeveloperApi :: @@ -416,10 +414,7 @@ object ALS extends Readable[ALS] with Logging { case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) @Since("1.6.0") - override def read: Reader[ALS] = new DefaultParamsReader[ALS] - - @Since("1.6.0") - override def load(path: String): ALS = read.load(path) + override def load(path: String): ALS = super.load(path) /** Trait for least squares solvers applied to the normal equation. */ private[recommendation] trait LeastSquaresNESolver extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f7c44f0a51b8..7ba1a60edaf7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -66,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams @Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams with Writable with Logging { + with LinearRegressionParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("linReg")) @@ -345,19 +345,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String @Since("1.4.0") override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) - - @Since("1.6.0") - override def write: Writer = new DefaultParamsWriter(this) } @Since("1.6.0") -object LinearRegression extends Readable[LinearRegression] { - - @Since("1.6.0") - override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression] +object LinearRegression extends DefaultParamsReadable[LinearRegression] { @Since("1.6.0") - override def load(path: String): LinearRegression = read.load(path) + override def load(path: String): LinearRegression = super.load(path) } /** @@ -371,7 +365,7 @@ class LinearRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with Writable { + with LinearRegressionParams with MLWritable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -441,7 +435,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[Writer]] instance for this ML instance. + * Returns a [[MLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. 
@@ -449,21 +443,21 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) } @Since("1.6.0") -object LinearRegressionModel extends Readable[LinearRegressionModel] { +object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") - override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader + override def read: MLReader[LinearRegressionModel] = new LinearRegressionModelReader @Since("1.6.0") - override def load(path: String): LinearRegressionModel = read.load(path) + override def load(path: String): LinearRegressionModel = super.load(path) - /** [[Writer]] instance for [[LinearRegressionModel]] */ + /** [[MLWriter]] instance for [[LinearRegressionModel]] */ private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends Writer with Logging { + extends MLWriter with Logging { private case class Data(intercept: Double, coefficients: Vector) @@ -477,7 +471,7 @@ object LinearRegressionModel extends Readable[LinearRegressionModel] { } } - private class LinearRegressionModelReader extends Reader[LinearRegressionModel] { + private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ private val className = "org.apache.spark.ml.regression.LinearRegressionModel" diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index d8ce907af532..ff9322dba122 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils /** - * Trait for [[Writer]] and [[Reader]]. + * Trait for [[MLWriter]] and [[MLReader]]. */ private[util] sealed trait BaseReadWrite { private var optionSQLContext: Option[SQLContext] = None @@ -64,7 +64,7 @@ private[util] sealed trait BaseReadWrite { */ @Experimental @Since("1.6.0") -abstract class Writer extends BaseReadWrite with Logging { +abstract class MLWriter extends BaseReadWrite with Logging { protected var shouldOverwrite: Boolean = false @@ -111,16 +111,16 @@ abstract class Writer extends BaseReadWrite with Logging { } /** - * Trait for classes that provide [[Writer]]. + * Trait for classes that provide [[MLWriter]]. */ @Since("1.6.0") -trait Writable { +trait MLWritable { /** - * Returns a [[Writer]] instance for this ML instance. + * Returns an [[MLWriter]] instance for this ML instance. */ @Since("1.6.0") - def write: Writer + def write: MLWriter /** * Saves this ML instance to the input path, a shortcut of `write.save(path)`. @@ -130,13 +130,18 @@ trait Writable { def save(path: String): Unit = write.save(path) } +private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => + + override def write: MLWriter = new DefaultParamsWriter(this) +} + /** * Abstract class for utility classes that can load ML instances. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -abstract class Reader[T] extends BaseReadWrite { +abstract class MLReader[T] extends BaseReadWrite { /** * Loads the ML component from the input path. 
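To make the refactoring pattern concrete: the DefaultParamsWritable / DefaultParamsReadable traits added in ReadWrite.scala are what let the stages above drop their hand-written write/read overrides. The following is only an illustrative sketch, not part of the patch; NoOpTransformer is hypothetical, and because the traits are private[ml] such code has to live under org.apache.spark.ml.

    package org.apache.spark.ml.feature

    import org.apache.spark.ml.Transformer
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.types.StructType

    // Hypothetical no-op stage: DefaultParamsWritable supplies write(),
    // and the DefaultParamsReadable companion supplies read()/load().
    class NoOpTransformer(override val uid: String) extends Transformer with DefaultParamsWritable {
      def this() = this(Identifiable.randomUID("noOp"))
      override def transform(dataset: DataFrame): DataFrame = dataset
      override def transformSchema(schema: StructType): StructType = schema
      override def copy(extra: ParamMap): NoOpTransformer = defaultCopy(extra)
    }

    object NoOpTransformer extends DefaultParamsReadable[NoOpTransformer]

    // Usage sketch: new NoOpTransformer().save(path); NoOpTransformer.load(path)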
@@ -149,18 +154,18 @@ abstract class Reader[T] extends BaseReadWrite { } /** - * Trait for objects that provide [[Reader]]. + * Trait for objects that provide [[MLReader]]. * @tparam T ML instance type */ @Experimental @Since("1.6.0") -trait Readable[T] { +trait MLReadable[T] { /** - * Returns a [[Reader]] instance for this class. + * Returns an [[MLReader]] instance for this class. */ @Since("1.6.0") - def read: Reader[T] + def read: MLReader[T] /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. @@ -171,13 +176,18 @@ trait Readable[T] { def load(path: String): T = read.load(path) } +private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { + + override def read: MLReader[T] = new DefaultParamsReader +} + /** - * Default [[Writer]] implementation for transformers and estimators that contain basic + * Default [[MLWriter]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @param instance object to save */ -private[ml] class DefaultParamsWriter(instance: Params) extends Writer { +private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter { override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) @@ -218,13 +228,13 @@ private[ml] object DefaultParamsWriter { } /** - * Default [[Reader]] implementation for transformers and estimators that contain basic + * Default [[MLReader]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). * @tparam T ML instance type * TODO: Consider adding check for correct class name. 
*/ -private[ml] class DefaultParamsReader[T] extends Reader[T] { +private[ml] class DefaultParamsReader[T] extends MLReader[T] { override def load(path: String): T = { val metadata = DefaultParamsReader.loadMetadata(path, sc) diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 7f5c3895acb0..12aba6bc6dbe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -179,8 +179,8 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } -/** Used to test [[Pipeline]] with [[Writable]] stages */ -class WritableStage(override val uid: String) extends Transformer with Writable { +/** Used to test [[Pipeline]] with [[MLWritable]] stages */ +class WritableStage(override val uid: String) extends Transformer with MLWritable { final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -192,21 +192,21 @@ class WritableStage(override val uid: String) extends Transformer with Writable override def copy(extra: ParamMap): WritableStage = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) override def transform(dataset: DataFrame): DataFrame = dataset override def transformSchema(schema: StructType): StructType = schema } -object WritableStage extends Readable[WritableStage] { +object WritableStage extends MLReadable[WritableStage] { - override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage] + override def read: MLReader[WritableStage] = new DefaultParamsReader[WritableStage] - override def load(path: String): WritableStage = read.load(path) + override def load(path: String): WritableStage = super.load(path) } -/** Used to test [[Pipeline]] with non-[[Writable]] stages */ +/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */ class UnWritableStage(override val uid: String) extends Transformer { final val intParam: IntParam = new IntParam(this, "intParam", "doc") diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index dd1e8acce941..84d06b43d622 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -38,7 +38,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam T ML instance type * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with Writable]( + def testDefaultReadWrite[T <: Params with MLWritable]( instance: T, testParams: Boolean = true): T = { val uid = instance.uid @@ -52,7 +52,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => instance.save(path) } instance.write.overwrite().save(path) - val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]] + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] val newInstance = loader.load(path) assert(newInstance.uid === instance.uid) @@ -92,7 +92,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * @tparam E Type of [[Estimator]] * @tparam M Type of [[Model]] produced by estimator */ - def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable]( + def testEstimatorAndModelReadWrite[ + E <: Estimator[M] with 
MLWritable, M <: Model[M] with MLWritable]( estimator: E, dataset: DataFrame, testParams: Map[String, Any], @@ -119,7 +120,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => } } -class MyParams(override val uid: String) extends Params with Writable { +class MyParams(override val uid: String) extends Params with MLWritable { final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") final val intParam: IntParam = new IntParam(this, "intParam", "doc") @@ -145,14 +146,14 @@ class MyParams(override val uid: String) extends Params with Writable { override def copy(extra: ParamMap): Params = defaultCopy(extra) - override def write: Writer = new DefaultParamsWriter(this) + override def write: MLWriter = new DefaultParamsWriter(this) } -object MyParams extends Readable[MyParams] { +object MyParams extends MLReadable[MyParams] { - override def read: Reader[MyParams] = new DefaultParamsReader[MyParams] + override def read: MLReader[MyParams] = new DefaultParamsReader[MyParams] - override def load(path: String): MyParams = read.load(path) + override def load(path: String): MyParams = super.load(path) } class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext From e61367b9f9bfc8e123369d55d7ca5925568b98a7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 18 Nov 2015 18:34:36 -0800 Subject: [PATCH 789/896] [SPARK-11833][SQL] Add Java tests for Kryo/Java Dataset encoders Also added some nicer error messages for incompatible types (private types and primitive types) for Kryo/Java encoder. Author: Reynold Xin Closes #9823 from rxin/SPARK-11833. --- .../scala/org/apache/spark/sql/Encoder.scala | 69 +++++++++++------ .../encoders/EncoderErrorMessageSuite.scala | 40 ++++++++++ .../catalyst/encoders/FlatEncoderSuite.scala | 22 ++---- .../apache/spark/sql/JavaDatasetSuite.java | 75 ++++++++++++++++++- 4 files changed, 166 insertions(+), 40 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 1ed5111440c8..d54f2854fb33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.lang.reflect.Modifier + import scala.reflect.{ClassTag, classTag} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} @@ -43,30 +45,28 @@ trait Encoder[T] extends Serializable { */ object Encoders { - /** A way to construct encoders using generic serializers. 
*/ - private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { - ExpressionEncoder[T]( - schema = new StructType().add("value", BinaryType), - flat = true, - toRowExpressions = Seq( - EncodeUsingSerializer( - BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - fromRowExpression = - DecodeUsingSerializer[T]( - BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), - clsTag = classTag[T] - ) - } + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) + def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) + def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) + def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. */ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) /** * Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) @@ -75,6 +75,8 @@ object Encoders { * serialization. This encoder maps T into a single byte array (binary) field. * * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -83,17 +85,40 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) - def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) - def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) - def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + /** Throws an exception if T is not a public class. */ + private def validatePublicClass[T: ClassTag](): Unit = { + if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { + throw new UnsupportedOperationException( + s"${classTag[T].runtimeClass.getName} is not a public class. " + + "Only public classes are supported.") + } + } + + /** A way to construct encoders using generic serializers. 
*/ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { + if (classTag[T].runtimeClass.isPrimitive) { + throw new UnsupportedOperationException("Primitive types are not supported.") + } + + validatePublicClass[T]() + + ExpressionEncoder[T]( + schema = new StructType().add("value", BinaryType), + flat = true, + toRowExpressions = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + fromRowExpression = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + clsTag = classTag[T] + ) + } def tuple[T1, T2]( e1: Encoder[T1], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala new file mode 100644 index 000000000000..0b2a10bb04c1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders + + +class EncoderErrorMessageSuite extends SparkFunSuite { + + // Note: we also test error messages for encoders for private classes in JavaDatasetSuite. + // That is done in Java because Scala cannot create truly private classes. 
+ + test("primitive types in encoders using Kryo serialization") { + intercept[UnsupportedOperationException] { Encoders.kryo[Int] } + intercept[UnsupportedOperationException] { Encoders.kryo[Long] } + intercept[UnsupportedOperationException] { Encoders.kryo[Char] } + } + + test("primitive types in encoders using Java serialization") { + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Int] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } + intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala index 6e0322fb6e01..07523d49f426 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -74,24 +74,14 @@ class FlatEncoderSuite extends ExpressionEncoderSuite { FlatEncoder[Map[Int, Map[String, Int]]], "map of map") // Kryo encoders - encodeDecodeTest( - "hello", - encoderFor(Encoders.kryo[String]), - "kryo string") - encodeDecodeTest( - new KryoSerializable(15), - encoderFor(Encoders.kryo[KryoSerializable]), - "kryo object serialization") + encodeDecodeTest("hello", encoderFor(Encoders.kryo[String]), "kryo string") + encodeDecodeTest(new KryoSerializable(15), + encoderFor(Encoders.kryo[KryoSerializable]), "kryo object") // Java encoders - encodeDecodeTest( - "hello", - encoderFor(Encoders.javaSerialization[String]), - "java string") - encodeDecodeTest( - new JavaSerializable(15), - encoderFor(Encoders.javaSerialization[JavaSerializable]), - "java object serialization") + encodeDecodeTest("hello", encoderFor(Encoders.javaSerialization[String]), "java string") + encodeDecodeTest(new JavaSerializable(15), + encoderFor(Encoders.javaSerialization[JavaSerializable]), "java object") } /** For testing Kryo serialization based encoder. 
*/ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index d9b22506fbd3..ce40dd856f67 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -24,6 +24,7 @@ import scala.Tuple3; import scala.Tuple4; import scala.Tuple5; + import org.junit.*; import org.apache.spark.Accumulator; @@ -410,8 +411,8 @@ public String call(Tuple2 value) throws Exception { .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG())); Assert.assertEquals( Arrays.asList( - new Tuple4("a", 3, 3L, 2L), - new Tuple4("b", 3, 3L, 1L)), + new Tuple4<>("a", 3, 3L, 2L), + new Tuple4<>("b", 3, 3L, 1L)), agged2.collectAsList()); } @@ -437,4 +438,74 @@ public Integer finish(Integer reduction) { return reduction; } } + + public static class KryoSerializable { + String value; + + KryoSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + return this.value.equals(((KryoSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + public static class JavaSerializable implements Serializable { + String value; + + JavaSerializable(String value) { + this.value = value; + } + + @Override + public boolean equals(Object other) { + return this.value.equals(((JavaSerializable) other).value); + } + + @Override + public int hashCode() { + return this.value.hashCode(); + } + } + + @Test + public void testKryoEncoder() { + Encoder encoder = Encoders.kryo(KryoSerializable.class); + List data = Arrays.asList( + new KryoSerializable("hello"), new KryoSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + @Test + public void testJavaEncoder() { + Encoder encoder = Encoders.javaSerialization(JavaSerializable.class); + List data = Arrays.asList( + new JavaSerializable("hello"), new JavaSerializable("world")); + Dataset ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + + /** + * For testing error messages when creating an encoder on a private class. This is done + * here since we cannot create truly private classes in Scala. + */ + private static class PrivateClassTest { } + + @Test(expected = UnsupportedOperationException.class) + public void testJavaEncoderErrorMessageForPrivateClass() { + Encoders.javaSerialization(PrivateClassTest.class); + } + + @Test(expected = UnsupportedOperationException.class) + public void testKryoEncoderErrorMessageForPrivateClass() { + Encoders.kryo(PrivateClassTest.class); + } } From 6d0848b53bbe6c5acdcf5c033cd396b1ae6e293d Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 18 Nov 2015 18:38:45 -0800 Subject: [PATCH 790/896] [SPARK-11787][SQL] Improve Parquet scan performance when using flat schemas. This patch adds an alternate to the Parquet RecordReader from the parquet-mr project that is much faster for flat schemas. Instead of using the general converter mechanism from parquet-mr, this directly uses the lower level APIs from parquet-columnar and a customer RecordReader that directly assembles into UnsafeRows. This is optionally disabled and only used for supported schemas. 
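As a hedged usage sketch of the toggle described above (the flag name comes from the SqlNewHadoopRDD change below; the app name, path, and column are placeholders, assuming a Spark 1.6-style SQLContext):

    // The optimized reader is on by default and falls back to the parquet-mr
    // RecordReader for unsupported schemas; it can also be disabled explicitly.
    val conf = new org.apache.spark.SparkConf()
      .setAppName("parquet-scan-example")
      .set("spark.parquet.enableUnsafeRowRecordReader", "false") // force the old code path
    val sc = new org.apache.spark.SparkContext(conf)
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    sqlContext.read.parquet("/path/to/store_sales").selectExpr("sum(ss_quantity)").show()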
Using the tpcds store sales table and doing a sum of increasingly more columns, the results are: For 1 Column: Before: 11.3M rows/second After: 18.2M rows/second For 2 Columns: Before: 7.2M rows/second After: 11.2M rows/second For 5 Columns: Before: 2.9M rows/second After: 4.5M rows/second Author: Nong Li Closes #9774 from nongli/parquet. --- .../apache/spark/rdd/SqlNewHadoopRDD.scala | 41 +- .../sql/catalyst/expressions/UnsafeRow.java | 9 + .../expressions/codegen/BufferHolder.java | 32 +- .../expressions/codegen/UnsafeRowWriter.java | 20 +- .../SpecificParquetRecordReaderBase.java | 240 +++++++ .../parquet/UnsafeRowParquetRecordReader.java | 593 ++++++++++++++++++ .../parquet/CatalystRowConverter.scala | 48 +- .../parquet/ParquetFilterSuite.scala | 4 +- 8 files changed, 944 insertions(+), 43 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 264dae7f3908..4d176332b69c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -20,8 +20,6 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date -import scala.reflect.ClassTag - import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -30,10 +28,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager} import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} + +import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( @@ -96,6 +96,11 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( @transient protected val jobId = new JobID(jobTrackerId, id) + // If true, enable using the custom RecordReader for parquet. This only works for + // a subset of the types (no complex types). + protected val enableUnsafeRowParquetReader: Boolean = + sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true) + override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) val inputFormat = inputFormatClass.newInstance @@ -150,9 +155,31 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - private[this] var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + private[this] var reader: RecordReader[Void, V] = null + + /** + * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this + * fails (for example, unsupported schema), try with the normal reader. + * TODO: plumb this through a different way? 
+ */ + if (enableUnsafeRowParquetReader && + format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { + // TODO: move this class to sql.execution and remove this. + reader = Utils.classForName( + "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader") + .newInstance().asInstanceOf[RecordReader[Void, V]] + try { + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } catch { + case e: Exception => reader = null + } + } + + if (reader == null) { + reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5ba14ebdb62a..33769363a0ed 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -178,6 +178,15 @@ public void pointTo(byte[] buf, int numFields, int sizeInBytes) { pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); } + /** + * Updates this UnsafeRow preserving the number of fields. + * @param buf byte array to point to + * @param sizeInBytes the number of bytes valid in the byte array + */ + public void pointTo(byte[] buf, int sizeInBytes) { + pointTo(buf, numFields, sizeInBytes); + } + @Override public void setNullAt(int i) { assertIndexIsValid(i); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 9c9468678065..d26b1b187c27 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -17,19 +17,28 @@ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; /** - * A helper class to manage the row buffer used in `GenerateUnsafeProjection`. - * - * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables - * public for ease of use. + * A helper class to manage the row buffer when construct unsafe rows. */ public class BufferHolder { - public byte[] buffer = new byte[64]; + public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; - public void grow(int neededSize) { + public BufferHolder() { + this(64); + } + + public BufferHolder(int size) { + buffer = new byte[size]; + } + + /** + * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer. + */ + public void grow(int neededSize, UnsafeRow row) { final int length = totalSize() + neededSize; if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. 
@@ -41,12 +50,23 @@ public void grow(int neededSize) { Platform.BYTE_ARRAY_OFFSET, totalSize()); buffer = tmp; + if (row != null) { + row.pointTo(buffer, length * 2); + } } } + public void grow(int neededSize) { + grow(neededSize, null); + } + public void reset() { cursor = Platform.BYTE_ARRAY_OFFSET; } + public void resetTo(int offset) { + assert(offset <= buffer.length); + cursor = Platform.BYTE_ARRAY_OFFSET + offset; + } public int totalSize() { return cursor - Platform.BYTE_ARRAY_OFFSET; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 048b7749d8fb..e227c0dec974 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -35,6 +35,7 @@ public class UnsafeRowWriter { // The offset of the global buffer where we start to write this row. private int startingOffset; private int nullBitsSize; + private UnsafeRow row; public void initialize(BufferHolder holder, int numFields) { this.holder = holder; @@ -43,7 +44,7 @@ public void initialize(BufferHolder holder, int numFields) { // grow the global buffer to make sure it has enough space to write fixed-length data. final int fixedSize = nullBitsSize + 8 * numFields; - holder.grow(fixedSize); + holder.grow(fixedSize, row); holder.cursor += fixedSize; // zero-out the null bits region @@ -52,12 +53,19 @@ public void initialize(BufferHolder holder, int numFields) { } } + public void initialize(UnsafeRow row, BufferHolder holder, int numFields) { + initialize(holder, numFields); + this.row = row; + } + private void zeroOutPaddingBytes(int numBytes) { if ((numBytes & 0x07) > 0) { Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L); } } + public BufferHolder holder() { return holder; } + public boolean isNullAt(int ordinal) { return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal); } @@ -90,7 +98,7 @@ public void alignToWords(int numBytes) { if (remainder > 0) { final int paddingBytes = 8 - remainder; - holder.grow(paddingBytes); + holder.grow(paddingBytes, row); for (int i = 0; i < paddingBytes; i++) { Platform.putByte(holder.buffer, holder.cursor, (byte) 0); @@ -153,7 +161,7 @@ public void write(int ordinal, Decimal input, int precision, int scale) { } } else { // grow the global buffer before writing data. - holder.grow(16); + holder.grow(16, row); // zero-out the bytes Platform.putLong(holder.buffer, holder.cursor, 0L); @@ -185,7 +193,7 @@ public void write(int ordinal, UTF8String input) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -206,7 +214,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); // grow the global buffer before writing data. - holder.grow(roundedSize); + holder.grow(roundedSize, row); zeroOutPaddingBytes(numBytes); @@ -222,7 +230,7 @@ public void write(int ordinal, byte[] input, int offset, int numBytes) { public void write(int ordinal, CalendarInterval input) { // grow the global buffer before writing data. 
- holder.grow(16); + holder.grow(16, row); // Write the months and microseconds fields of Interval to the variable length portion. Platform.putLong(holder.buffer, holder.cursor, input.months); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java new file mode 100644 index 000000000000..2ed30c1f5a8d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; +import static org.apache.parquet.format.converter.ParquetMetadataConverter.range; +import static org.apache.parquet.hadoop.ParquetFileReader.readFooter; +import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.bytes.BytesInput; +import org.apache.parquet.bytes.BytesUtils; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.column.values.rle.RunLengthBitPackingHybridDecoder; +import org.apache.parquet.filter2.compat.FilterCompat; +import org.apache.parquet.hadoop.BadConfigurationException; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.ParquetInputFormat; +import org.apache.parquet.hadoop.ParquetInputSplit; +import org.apache.parquet.hadoop.api.InitContext; +import org.apache.parquet.hadoop.api.ReadSupport; +import org.apache.parquet.hadoop.metadata.BlockMetaData; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.hadoop.util.ConfigurationUtil; +import org.apache.parquet.schema.MessageType; + +/** + * Base class for custom RecordReaaders for Parquet that directly materialize to `T`. + * This class handles computing row groups, filtering on them, setting up the column readers, + * etc. + * This is heavily based on parquet-mr's RecordReader. 
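+ * initialize() handles both split layouts: with task-side metadata (no row-group offsets in the
+ * split) it reads the footer and applies the row-group filters here; otherwise it matches the
+ * offsets carried in the split against the footer's blocks.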
+ * TODO: move this to the parquet-mr project. There are performance benefits of doing it + * this way, albeit at a higher cost to implement. This base class is reusable. + */ +public abstract class SpecificParquetRecordReaderBase extends RecordReader { + protected Path file; + protected MessageType fileSchema; + protected MessageType requestedSchema; + protected ReadSupport readSupport; + + /** + * The total number of rows this RecordReader will eventually read. The sum of the + * rows of all the row groups. + */ + protected long totalRowCount; + + protected ParquetFileReader reader; + + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + Configuration configuration = taskAttemptContext.getConfiguration(); + ParquetInputSplit split = (ParquetInputSplit)inputSplit; + this.file = split.getPath(); + long[] rowGroupOffsets = split.getRowGroupOffsets(); + + ParquetMetadata footer; + List blocks; + + // if task.side.metadata is set, rowGroupOffsets is null + if (rowGroupOffsets == null) { + // then we need to apply the predicate push down filter + footer = readFooter(configuration, file, range(split.getStart(), split.getEnd())); + MessageType fileSchema = footer.getFileMetaData().getSchema(); + FilterCompat.Filter filter = getFilter(configuration); + blocks = filterRowGroups(filter, footer.getBlocks(), fileSchema); + } else { + // otherwise we find the row groups that were selected on the client + footer = readFooter(configuration, file, NO_FILTER); + Set offsets = new HashSet<>(); + for (long offset : rowGroupOffsets) { + offsets.add(offset); + } + blocks = new ArrayList<>(); + for (BlockMetaData block : footer.getBlocks()) { + if (offsets.contains(block.getStartingPos())) { + blocks.add(block); + } + } + // verify we found them all + if (blocks.size() != rowGroupOffsets.length) { + long[] foundRowGroupOffsets = new long[footer.getBlocks().size()]; + for (int i = 0; i < foundRowGroupOffsets.length; i++) { + foundRowGroupOffsets[i] = footer.getBlocks().get(i).getStartingPos(); + } + // this should never happen. + // provide a good error message in case there's a bug + throw new IllegalStateException( + "All the offsets listed in the split should be found in the file." + + " expected: " + Arrays.toString(rowGroupOffsets) + + " found: " + blocks + + " out of: " + Arrays.toString(foundRowGroupOffsets) + + " in range " + split.getStart() + ", " + split.getEnd()); + } + } + MessageType fileSchema = footer.getFileMetaData().getSchema(); + Map fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); + this.readSupport = getReadSupportInstance( + (Class>) getReadSupportClass(configuration)); + ReadSupport.ReadContext readContext = readSupport.init(new InitContext( + taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); + this.requestedSchema = readContext.getRequestedSchema(); + this.fileSchema = fileSchema; + this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns()); + for (BlockMetaData block : blocks) { + this.totalRowCount += block.getRowCount(); + } + } + + @Override + public Void getCurrentKey() throws IOException, InterruptedException { + return null; + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + reader = null; + } + } + + /** + * Utility classes to abstract over different way to read ints with different encodings. + * TODO: remove this layer of abstraction? 
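+   * (Three flavours below: ValuesReaderIntIterator for plain values readers, RLEIntIterator for
+   * run-length/bit-packed level data, and NullIntIterator for columns whose max level is 0.)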
+ */ + abstract static class IntIterator { + abstract int nextInt() throws IOException; + } + + protected static final class ValuesReaderIntIterator extends IntIterator { + ValuesReader delegate; + + public ValuesReaderIntIterator(ValuesReader delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInteger(); + } + } + + protected static final class RLEIntIterator extends IntIterator { + RunLengthBitPackingHybridDecoder delegate; + + public RLEIntIterator(RunLengthBitPackingHybridDecoder delegate) { + this.delegate = delegate; + } + + @Override + int nextInt() throws IOException { + return delegate.readInt(); + } + } + + protected static final class NullIntIterator extends IntIterator { + @Override + int nextInt() throws IOException { return 0; } + } + + /** + * Creates a reader for definition and repetition levels, returning an optimized one if + * the levels are not needed. + */ + static protected IntIterator createRLEIterator(int maxLevel, BytesInput bytes, + ColumnDescriptor descriptor) throws IOException { + try { + if (maxLevel == 0) return new NullIntIterator(); + return new RLEIntIterator( + new RunLengthBitPackingHybridDecoder( + BytesUtils.getWidthFromMaxInt(maxLevel), + new ByteArrayInputStream(bytes.toByteArray()))); + } catch (IOException e) { + throw new IOException("could not read levels in page for col " + descriptor, e); + } + } + + private static Map> toSetMultiMap(Map map) { + Map> setMultiMap = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + Set set = new HashSet<>(); + set.add(entry.getValue()); + setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set)); + } + return Collections.unmodifiableMap(setMultiMap); + } + + private static Class getReadSupportClass(Configuration configuration) { + return ConfigurationUtil.getClassFromConfig(configuration, + ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class); + } + + /** + * @param readSupportClass to instantiate + * @return the configured read support + */ + private static ReadSupport getReadSupportInstance( + Class> readSupportClass){ + try { + return readSupportClass.newInstance(); + } catch (InstantiationException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } catch (IllegalAccessException e) { + throw new BadConfigurationException("could not instantiate read support class", e); + } + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java new file mode 100644 index 000000000000..8a92e489ccb7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -0,0 +1,593 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL; +import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; +import static org.apache.parquet.column.ValuesType.VALUES; + +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.parquet.Preconditions; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.column.Encoding; +import org.apache.parquet.column.page.DataPage; +import org.apache.parquet.column.page.DataPageV1; +import org.apache.parquet.column.page.DataPageV2; +import org.apache.parquet.column.page.DictionaryPage; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.column.page.PageReader; +import org.apache.parquet.column.values.ValuesReader; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +/** + * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs. + * + * This is somewhat based on parquet-mr's ColumnReader. + * + * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch. + * All of these can be handled efficiently and easily with codegen. + */ +public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase { + /** + * Batch of unsafe rows that we assemble and the current index we've returned. Everytime this + * batch is used up (batchIdx == numBatched), we populated the batch. + */ + private UnsafeRow[] rows = new UnsafeRow[64]; + private int batchIdx = 0; + private int numBatched = 0; + + /** + * Used to write variable length columns. Same length as `rows`. + */ + private UnsafeRowWriter[] rowWriters = null; + /** + * True if the row contains variable length fields. + */ + private boolean containsVarLenFields; + + /** + * The number of bytes in the fixed length portion of the row. + */ + private int fixedSizeBytes; + + /** + * For each request column, the reader to read this column. + * columnsReaders[i] populated the UnsafeRow's attribute at i. + */ + private ColumnReader[] columnReaders; + + /** + * The number of rows that have been returned. + */ + private long rowsReturned; + + /** + * The number of rows that have been reading, including the current in flight row group. + */ + private long totalCountLoadedSoFar = 0; + + /** + * For each column, the annotated original type. 
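+   * (Consulted in loadBatch() to pick the decoder: for example, DECIMAL-annotated INT32 columns
+   * are written as compact longs and UTF8-annotated BINARY columns as UTF8String; any annotation
+   * other than UTF8, DECIMAL or DATE is rejected in initialize().)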
+ */ + private OriginalType[] originalTypes; + + /** + * The default size for varlen columns. The row grows as necessary to accommodate the + * largest column. + */ + private static final int DEFAULT_VAR_LEN_SIZE = 32; + + /** + * Implementation of RecordReader API. + */ + @Override + public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) + throws IOException, InterruptedException { + super.initialize(inputSplit, taskAttemptContext); + + /** + * Check that the requested schema is supported. + */ + if (requestedSchema.getFieldCount() == 0) { + // TODO: what does this mean? + throw new IOException("Empty request schema not supported."); + } + int numVarLenFields = 0; + originalTypes = new OriginalType[requestedSchema.getFieldCount()]; + for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { + Type t = requestedSchema.getFields().get(i); + if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { + throw new IOException("Complex types not supported."); + } + PrimitiveType primitiveType = t.asPrimitiveType(); + + originalTypes[i] = t.getOriginalType(); + + // TODO: Be extremely cautious in what is supported. Expand this. + if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL && + originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) { + throw new IOException("Unsupported type: " + t); + } + if (originalTypes[i] == OriginalType.DECIMAL && + primitiveType.getDecimalMetadata().getPrecision() > + CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + throw new IOException("Decimal with high precision is not supported."); + } + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { + throw new IOException("Int96 not supported."); + } + ColumnDescriptor fd = fileSchema.getColumnDescription(requestedSchema.getPaths().get(i)); + if (!fd.equals(requestedSchema.getColumns().get(i))) { + throw new IOException("Schema evolution not supported."); + } + + if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY) { + ++numVarLenFields; + } + } + + /** + * Initialize rows and rowWriters. These objects are reused across all rows in the relation. + */ + int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount()); + rowByteSize += 8 * requestedSchema.getFieldCount(); + fixedSizeBytes = rowByteSize; + rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE; + containsVarLenFields = numVarLenFields > 0; + rowWriters = new UnsafeRowWriter[rows.length]; + + for (int i = 0; i < rows.length; ++i) { + rows[i] = new UnsafeRow(); + rowWriters[i] = new UnsafeRowWriter(); + BufferHolder holder = new BufferHolder(rowByteSize); + rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); + rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(), + holder.buffer.length); + } + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (batchIdx >= numBatched) { + if (!loadBatch()) return false; + } + ++batchIdx; + return true; + } + + @Override + public UnsafeRow getCurrentValue() throws IOException, InterruptedException { + return rows[batchIdx - 1]; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return (float) rowsReturned / totalRowCount; + } + + /** + * Decodes a batch of values into `rows`. This function is the hot path. 
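+   * Returns false once every row has been returned; otherwise it decodes up to rows.length rows,
+   * one column at a time, and resets batchIdx so getCurrentValue() serves the new batch.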
+ */ + private boolean loadBatch() throws IOException { + // no more records left + if (rowsReturned >= totalRowCount) { return false; } + checkEndOfRowGroup(); + + int num = (int)Math.min(rows.length, totalCountLoadedSoFar - rowsReturned); + rowsReturned += num; + + if (containsVarLenFields) { + for (int i = 0; i < rowWriters.length; ++i) { + rowWriters[i].holder().resetTo(fixedSizeBytes); + } + } + + for (int i = 0; i < columnReaders.length; ++i) { + switch (columnReaders[i].descriptor.getType()) { + case BOOLEAN: + decodeBooleanBatch(i, num); + break; + case INT32: + if (originalTypes[i] == OriginalType.DECIMAL) { + decodeIntAsDecimalBatch(i, num); + } else { + decodeIntBatch(i, num); + } + break; + case INT64: + Preconditions.checkState(originalTypes[i] == null + || originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeLongBatch(i, num); + break; + case FLOAT: + decodeFloatBatch(i, num); + break; + case DOUBLE: + decodeDoubleBatch(i, num); + break; + case BINARY: + decodeBinaryBatch(i, num); + break; + case FIXED_LEN_BYTE_ARRAY: + Preconditions.checkState(originalTypes[i] == OriginalType.DECIMAL, + "Unexpected original type: " + originalTypes[i]); + decodeFixedLenArrayAsDecimalBatch(i, num); + break; + case INT96: + throw new IOException("Unsupported " + columnReaders[i].descriptor.getType()); + } + numBatched = num; + batchIdx = 0; + } + return true; + } + + private void decodeBooleanBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setBoolean(col, columnReaders[col].nextBoolean()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setInt(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeIntAsDecimalBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + // Since this is stored as an INT, it is always a compact decimal. Just set it as a long. 
+ rows[n].setLong(col, columnReaders[col].nextInt()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeLongBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setLong(col, columnReaders[col].nextLong()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFloatBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setFloat(col, columnReaders[col].nextFloat()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeDoubleBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + rows[n].setDouble(col, columnReaders[col].nextDouble()); + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeBinaryBatch(int col, int num) throws IOException { + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer(); + int len = bytes.limit() - bytes.position(); + if (originalTypes[col] == OriginalType.UTF8) { + UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len); + rowWriters[n].write(col, str); + } else { + rowWriters[n].write(col, bytes.array(), bytes.position(), len); + } + } else { + rows[n].setNullAt(col); + } + } + } + + private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOException { + PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); + int precision = type.getDecimalMetadata().getPrecision(); + int scale = type.getDecimalMetadata().getScale(); + Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + "Unsupported precision."); + + for (int n = 0; n < num; ++n) { + if (columnReaders[col].next()) { + Binary v = columnReaders[col].nextBinary(); + // Constructs a `Decimal` with an unscaled `Long` value if possible. + long unscaled = CatalystRowConverter.binaryToUnscaledLong(v); + rows[n].setDecimal(col, Decimal.apply(unscaled, precision, scale), precision); + } else { + rows[n].setNullAt(col); + } + } + } + + /** + * + * Decoder to return values from a single column. + */ + private static final class ColumnReader { + /** + * Total number of values read. + */ + private long valuesRead; + + /** + * value that indicates the end of the current page. That is, + * if valuesRead == endOfPageValueCount, we are at the end of the page. + */ + private long endOfPageValueCount; + + /** + * The dictionary, if this column has dictionary encoding. + */ + private final Dictionary dictionary; + + /** + * If true, the current page is dictionary encoded. + */ + private boolean useDictionary; + + /** + * Maximum definition level for this column. + */ + private final int maxDefLevel; + + /** + * Repetition/Definition/Value readers. + */ + private IntIterator repetitionLevelColumn; + private IntIterator definitionLevelColumn; + private ValuesReader dataColumn; + + /** + * Total number of values in this column (in this row group). + */ + private final long totalValueCount; + + /** + * Total values in the current page. 
+ */ + private int pageValueCount; + + private final PageReader pageReader; + private final ColumnDescriptor descriptor; + + public ColumnReader(ColumnDescriptor descriptor, PageReader pageReader) + throws IOException { + this.descriptor = descriptor; + this.pageReader = pageReader; + this.maxDefLevel = descriptor.getMaxDefinitionLevel(); + + DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); + if (dictionaryPage != null) { + try { + this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage); + this.useDictionary = true; + } catch (IOException e) { + throw new IOException("could not decode the dictionary for " + descriptor, e); + } + } else { + this.dictionary = null; + this.useDictionary = false; + } + this.totalValueCount = pageReader.getTotalValueCount(); + if (totalValueCount == 0) { + throw new IOException("totalValueCount == 0"); + } + } + + /** + * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned. + */ + public boolean nextBoolean() { + if (!useDictionary) { + return dataColumn.readBoolean(); + } else { + return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId()); + } + } + + public int nextInt() { + if (!useDictionary) { + return dataColumn.readInteger(); + } else { + return dictionary.decodeToInt(dataColumn.readValueDictionaryId()); + } + } + + public long nextLong() { + if (!useDictionary) { + return dataColumn.readLong(); + } else { + return dictionary.decodeToLong(dataColumn.readValueDictionaryId()); + } + } + + public float nextFloat() { + if (!useDictionary) { + return dataColumn.readFloat(); + } else { + return dictionary.decodeToFloat(dataColumn.readValueDictionaryId()); + } + } + + public double nextDouble() { + if (!useDictionary) { + return dataColumn.readDouble(); + } else { + return dictionary.decodeToDouble(dataColumn.readValueDictionaryId()); + } + } + + public Binary nextBinary() { + if (!useDictionary) { + return dataColumn.readBytes(); + } else { + return dictionary.decodeToBinary(dataColumn.readValueDictionaryId()); + } + } + + /** + * Advances to the next value. Returns true if the value is non-null. + */ + private boolean next() throws IOException { + if (valuesRead >= endOfPageValueCount) { + if (valuesRead >= totalValueCount) { + // How do we get here? Throw end of stream exception? + return false; + } + readPage(); + } + ++valuesRead; + // TODO: Don't read for flat schemas + //repetitionLevel = repetitionLevelColumn.nextInt(); + return definitionLevelColumn.nextInt() == maxDefLevel; + } + + private void readPage() throws IOException { + DataPage page = pageReader.readPage(); + // TODO: Why is this a visitor? 
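+        // (The anonymous visitor below simply dispatches on the concrete page type, forwarding
+        // to readPageV1 or readPageV2.)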
+ page.accept(new DataPage.Visitor() { + @Override + public Void visit(DataPageV1 dataPageV1) { + try { + readPageV1(dataPageV1); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Void visit(DataPageV2 dataPageV2) { + try { + readPageV2(dataPageV2); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + }); + } + + private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset, int valueCount) + throws IOException { + this.pageValueCount = valueCount; + this.endOfPageValueCount = valuesRead + pageValueCount; + if (dataEncoding.usesDictionary()) { + if (dictionary == null) { + throw new IOException( + "could not read page in col " + descriptor + + " as the dictionary was missing for encoding " + dataEncoding); + } + this.dataColumn = dataEncoding.getDictionaryBasedValuesReader( + descriptor, VALUES, dictionary); + this.useDictionary = true; + } else { + this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES); + this.useDictionary = false; + } + + try { + dataColumn.initFromPage(pageValueCount, bytes, offset); + } catch (IOException e) { + throw new IOException("could not read page in col " + descriptor, e); + } + } + + private void readPageV1(DataPageV1 page) throws IOException { + ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); + ValuesReader dlReader = page.getDlEncoding().getValuesReader(descriptor, DEFINITION_LEVEL); + this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); + try { + byte[] bytes = page.getBytes().toByteArray(); + rlReader.initFromPage(pageValueCount, bytes, 0); + int next = rlReader.getNextOffset(); + dlReader.initFromPage(pageValueCount, bytes, next); + next = dlReader.getNextOffset(); + initDataReader(page.getValueEncoding(), bytes, next, page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + + private void readPageV2(DataPageV2 page) throws IOException { + this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(), + page.getRepetitionLevels(), descriptor); + this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(), + page.getDefinitionLevels(), descriptor); + try { + initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0, + page.getValueCount()); + } catch (IOException e) { + throw new IOException("could not read page " + page + " in col " + descriptor, e); + } + } + } + + private void checkEndOfRowGroup() throws IOException { + if (rowsReturned != totalCountLoadedSoFar) return; + PageReadStore pages = reader.readNextRowGroup(); + if (pages == null) { + throw new IOException("expecting more rows but reached last block. 
Read " + + rowsReturned + " out of " + totalRowCount); + } + List columns = requestedSchema.getColumns(); + columnReaders = new ColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); ++i) { + columnReaders[i] = new ColumnReader(columns.get(i), pages.getPageReader(columns.get(i))); + } + totalCountLoadedSoFar += pages.getRowCount(); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 1f653cd3d3cb..94298fae2d69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -370,35 +370,13 @@ private[parquet] class CatalystRowConverter( protected def decimalFromBinary(value: Binary): Decimal = { if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { // Constructs a `Decimal` with an unscaled `Long` value if possible. - val unscaled = binaryToUnscaledLong(value) + val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) } else { // Otherwise, resorts to an unscaled `BigInteger` instead. Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale) } } - - private def binaryToUnscaledLong(binary: Binary): Long = { - // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here - // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without - // copying it. - val buffer = binary.toByteBuffer - val bytes = buffer.array() - val start = buffer.position() - val end = buffer.limit() - - var unscaled = 0L - var i = start - - while (i < end) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } - - val bits = 8 * (end - start) - unscaled = (unscaled << (64 - bits)) >> (64 - bits) - unscaled - } } private class CatalystIntDictionaryAwareDecimalConverter( @@ -658,3 +636,27 @@ private[parquet] class CatalystRowConverter( override def start(): Unit = elementConverter.start() } } + +private[parquet] object CatalystRowConverter { + def binaryToUnscaledLong(binary: Binary): Long = { + // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here + // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without + // copying it. 
+ val buffer = binary.toByteBuffer + val bytes = buffer.array() + val start = buffer.position() + val end = buffer.limit() + + var unscaled = 0L + var i = start + + while (i < end) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } + + val bits = 8 * (end - start) + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + unscaled + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 458786f77af3..c8028a5ef552 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -337,7 +337,9 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does + // not do row by row filtering (and we probably don't want to push that). + ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => From 9c0654d36c6d171dd273850c2cc2f415cc2a5a6b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 18 Nov 2015 18:41:40 -0800 Subject: [PATCH 791/896] Revert "[SPARK-11544][SQL] sqlContext doesn't use PathFilter" This reverts commit 54db79702513e11335c33bcf3a03c59e965e6f16. --- .../apache/spark/sql/sources/interfaces.scala | 25 +++---------- .../datasources/json/JsonSuite.scala | 36 ++----------------- 2 files changed, 7 insertions(+), 54 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index f9465157c936..b3d3bdf50df6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,8 +21,7 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} -import org.apache.hadoop.mapred.{JobConf, FileInputFormat} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -448,15 +447,9 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + logInfo(s"Listing $qualified on driver") - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(hadoopConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - if (pathFilter != null) { - Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) - } else { - Try(fs.listStatus(qualified)).getOrElse(Array.empty) - } + Try(fs.listStatus(qualified)).getOrElse(Array.empty) }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -854,16 +847,8 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - // Dummy jobconf to get to the pathFilter 
defined in configuration - val jobConf = new JobConf(fs.getConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - if (pathFilter != null) { - val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index f09b61e83815..6042b1178aff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,27 +19,19 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} -import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.commons.io.FileUtils -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class TestFileFilter extends PathFilter { - override def accept(path: Path): Boolean = path.getParent.getName != "p=2" -} - class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1398,28 +1390,4 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } - - test("SPARK-11544 test pathfilter") { - withTempPath { dir => - val path = dir.getCanonicalPath - - val df = sqlContext.range(2) - df.write.json(path + "/p=1") - df.write.json(path + "/p=2") - assert(sqlContext.read.json(path).count() === 4) - - val clonedConf = new Configuration(hadoopConfiguration) - try { - hadoopConfiguration.setClass( - "mapreduce.input.pathFilter.class", - classOf[TestFileFilter], - classOf[PathFilter]) - assert(sqlContext.read.json(path).count() === 2) - } finally { - // Hadoop 1 doesn't have `Configuration.unset` - hadoopConfiguration.clear() - clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) - } - } - } } From 67c75828ff4df2e305bdf5d6be5a11201d1da3f3 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 18 Nov 2015 18:49:46 -0800 Subject: [PATCH 792/896] [SPARK-11816][ML] fix some style issue in ML/MLlib examples jira: https://issues.apache.org/jira/browse/SPARK-11816 Currently I only fixed some obvious comments issue like // scalastyle:off println on the bottom. Yet the style in examples is not quite consistent, like only half of the examples are with // Example usage: ./bin/run-example mllib.FPGrowthExample \, Author: Yuhao Yang Closes #9808 from hhbyyh/exampleStyle. 
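For context, the convention these examples are being aligned to pairs a scalastyle suppression at the top of the file with its re-enable at the bottom; a minimal sketch (the object name is illustrative, the two marker comments are the pattern this patch fixes):

    // scalastyle:off println
    object SomeExample {
      def main(args: Array[String]): Unit = {
        println("printing is only allowed while the check is switched off above")
      }
    }
    // scalastyle:on println  // re-enable the check at the end of the file, not a second :off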
--- .../java/org/apache/spark/examples/ml/JavaKMeansExample.java | 2 +- .../apache/spark/examples/ml/AFTSurvivalRegressionExample.scala | 2 +- .../spark/examples/ml/DecisionTreeClassificationExample.scala | 1 + .../spark/examples/ml/DecisionTreeRegressionExample.scala | 1 + .../examples/ml/MultilayerPerceptronClassifierExample.scala | 2 +- 5 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index be2bf0c7b465..47665ff2b1f3 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -41,7 +41,7 @@ * An example demonstrating a k-means clustering. * Run with *
- * bin/run-example ml.JavaSimpleParamsExample <file> <k>
+ * bin/run-example ml.JavaKMeansExample <file> <k>
* </pre>
    */ public class JavaKMeansExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala index 5da285e83681..f4b3613ccb94 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -59,4 +59,4 @@ object AFTSurvivalRegressionExample { sc.stop() } } -// scalastyle:off println +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index ff8a0a90f1e4..db024b5cad93 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -90,3 +90,4 @@ object DecisionTreeClassificationExample { // $example off$ } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index fc402724d215..ad01f55df72b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -78,3 +78,4 @@ object DecisionTreeRegressionExample { // $example off$ } } +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 146b83c8be49..9c98076bd24b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,4 +66,4 @@ object MultilayerPerceptronClassifierExample { sc.stop() } } -// scalastyle:off println +// scalastyle:on println From fc3f77b42d62ca789d0ee07403795978961991c7 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Wed, 18 Nov 2015 19:37:14 -0800 Subject: [PATCH 793/896] [SPARK-11614][SQL] serde parameters should be set only when all params are ready see HIVE-7975 and HIVE-12373 With changed semantic of setters in thrift objects in hive, setter should be called only after all parameters are set. It's not problem of current state but will be a problem in some day. Author: navis.ryu Closes #9580 from navis/SPARK-11614. 
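Put differently, the map handed to a thrift setter should already be complete, since the setter can no longer be assumed to keep a live reference to it; a minimal sketch of the safe ordering (the property values are illustrative; in HiveMetastoreCatalog they come from the catalog table and partition, as in the diff below):

    import org.apache.hadoop.hive.metastore.api.SerDeInfo

    val props = Map("serialization.format" -> ",", "field.delim" -> ",")   // illustrative
    val serdeInfo = new SerDeInfo
    val serdeParameters = new java.util.HashMap[String, String]()
    props.foreach { case (k, v) => serdeParameters.put(k, v) }
    serdeInfo.setParameters(serdeParameters)   // called last, once the map is fully populated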
--- .../scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f4d45714fae4..9a981d02ad67 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -804,12 +804,13 @@ private[hive] case class MetastoreRelation val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo sd.setSerdeInfo(serdeInfo) + // maps and lists should be set only after all elements are ready (see HIVE-7975) serdeInfo.setSerializationLib(p.storage.serde) val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + serdeInfo.setParameters(serdeParameters) new Partition(hiveQlTable, tPartition) } From d02d5b9295b169c3ebb0967453b2835edb8a121f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 18 Nov 2015 21:44:01 -0800 Subject: [PATCH 794/896] [SPARK-11842][ML] Small cleanups to existing Readers and Writers Updates: * Add repartition(1) to save() methods' saving of data for LogisticRegressionModel, LinearRegressionModel. * Strengthen privacy to class and companion object for Writers and Readers * Change LogisticRegressionSuite read/write test to fit intercept * Add Since versions for read/write methods in Pipeline, LogisticRegression * Switch from hand-written class names in Readers to using getClass CC: mengxr CC: yanboliang Would you mind taking a look at this PR? mengxr might not be able to soon. Thank you! Author: Joseph K. Bradley Closes #9829 from jkbradley/ml-io-cleanups. 
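The shape the writers and readers converge on after this cleanup, as a sketch lifted from the LinearRegressionModel changes in the diff below (Data, instance, path and sqlContext are the surrounding writer's members, not new API):

    import org.apache.hadoop.fs.Path

    // Writer side: model data is tiny, so write it as a single parquet part file.
    val dataPath = new Path(path, "data").toString
    sqlContext.createDataFrame(Seq(Data(instance.intercept, instance.coefficients)))
      .repartition(1).write.parquet(dataPath)

    // Reader side: derive the expected class name instead of hand-writing the string.
    val className = classOf[LinearRegressionModel].getName
    val metadata = DefaultParamsReader.loadMetadata(path, sc, className)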
--- .../scala/org/apache/spark/ml/Pipeline.scala | 22 +++++++++++++------ .../classification/LogisticRegression.scala | 19 ++++++++++------ .../spark/ml/feature/CountVectorizer.scala | 2 +- .../org/apache/spark/ml/feature/IDF.scala | 2 +- .../spark/ml/feature/MinMaxScaler.scala | 2 +- .../spark/ml/feature/StandardScaler.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- .../apache/spark/ml/recommendation/ALS.scala | 6 ++--- .../ml/regression/LinearRegression.scala | 4 ++-- .../LogisticRegressionSuite.scala | 2 +- 10 files changed, 38 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index b0f22e042ec5..6f15b37abcb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -27,7 +27,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util.MLReader import org.apache.spark.ml.util.MLWriter @@ -174,16 +174,20 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur)) } + @Since("1.6.0") override def write: MLWriter = new Pipeline.PipelineWriter(this) } +@Since("1.6.0") object Pipeline extends MLReadable[Pipeline] { + @Since("1.6.0") override def read: MLReader[Pipeline] = new PipelineReader + @Since("1.6.0") override def load(path: String): Pipeline = super.load(path) - private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter { + private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter { SharedReadWrite.validateStages(instance.getStages) @@ -191,10 +195,10 @@ object Pipeline extends MLReadable[Pipeline] { SharedReadWrite.saveImpl(instance, instance.getStages, sc, path) } - private[ml] class PipelineReader extends MLReader[Pipeline] { + private class PipelineReader extends MLReader[Pipeline] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.Pipeline" + private val className = classOf[Pipeline].getName override def load(path: String): Pipeline = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) @@ -333,18 +337,22 @@ class PipelineModel private[ml] ( new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } + @Since("1.6.0") override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) } +@Since("1.6.0") object PipelineModel extends MLReadable[PipelineModel] { import Pipeline.SharedReadWrite + @Since("1.6.0") override def read: MLReader[PipelineModel] = new PipelineModelReader + @Since("1.6.0") override def load(path: String): PipelineModel = super.load(path) - private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { + private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter { SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]]) @@ -352,10 +360,10 @@ object PipelineModel extends MLReadable[PipelineModel] { instance.stages.asInstanceOf[Array[PipelineStage]], sc, path) } - private[ml] class PipelineModelReader extends MLReader[PipelineModel] { + private class PipelineModelReader extends 
MLReader[PipelineModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.PipelineModel" + private val className = classOf[PipelineModel].getName override def load(path: String): PipelineModel = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a3cc49f7f018..418bbdc9a058 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -525,18 +525,23 @@ class LogisticRegressionModel private[ml] ( * * This also does not save the [[parent]] currently. */ + @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) } +@Since("1.6.0") object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { + @Since("1.6.0") override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader + @Since("1.6.0") override def load(path: String): LogisticRegressionModel = super.load(path) /** [[MLWriter]] instance for [[LogisticRegressionModel]] */ - private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel) + private[LogisticRegressionModel] + class LogisticRegressionModelWriter(instance: LogisticRegressionModel) extends MLWriter with Logging { private case class Data( @@ -552,15 +557,15 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } - private[classification] class LogisticRegressionModelReader + private class LogisticRegressionModelReader extends MLReader[LogisticRegressionModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" + private val className = classOf[LogisticRegressionModel].getName override def load(path: String): LogisticRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) @@ -603,7 +608,7 @@ private[classification] class MultiClassSummarizer extends Serializable { * @return This MultilabelSummarizer */ def add(label: Double, weight: Double = 1.0): this.type = { - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -839,7 +844,7 @@ private class LogisticAggregator( instance match { case Instance(label, weight, features) => require(dim == features.size, s"Dimensions mismatch when adding new instance." 
+ s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 4969cf42450d..b9e2144c0ad4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -266,7 +266,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] { - private val className = "org.apache.spark.ml.feature.CountVectorizerModel" + private val className = classOf[CountVectorizerModel].getName override def load(path: String): CountVectorizerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 0e00ef6f2ee2..f7b0f29a27c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -155,7 +155,7 @@ object IDFModel extends MLReadable[IDFModel] { private class IDFModelReader extends MLReader[IDFModel] { - private val className = "org.apache.spark.ml.feature.IDFModel" + private val className = classOf[IDFModel].getName override def load(path: String): IDFModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index ed24eabb5044..c2866f5eceff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -210,7 +210,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] { - private val className = "org.apache.spark.ml.feature.MinMaxScalerModel" + private val className = classOf[MinMaxScalerModel].getName override def load(path: String): MinMaxScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1f689c1da1ba..6d545219ebf4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -180,7 +180,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private class StandardScalerModelReader extends MLReader[StandardScalerModel] { - private val className = "org.apache.spark.ml.feature.StandardScalerModel" + private val className = classOf[StandardScalerModel].getName override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 97a2e4f6d6ca..5c40c35eeaa4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -210,7 +210,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { private class StringIndexerModelReader extends MLReader[StringIndexerModel] { - private val className = "org.apache.spark.ml.feature.StringIndexerModel" + private val className = classOf[StringIndexerModel].getName override def load(path: String): StringIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 795b73c4c212..4d35177ad9b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -237,7 +237,7 @@ object ALSModel extends MLReadable[ALSModel] { @Since("1.6.0") override def load(path: String): ALSModel = super.load(path) - private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter { + private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { val extraMetadata = render("rank" -> instance.rank) @@ -249,10 +249,10 @@ object ALSModel extends MLReadable[ALSModel] { } } - private[recommendation] class ALSModelReader extends MLReader[ALSModel] { + private class ALSModelReader extends MLReader[ALSModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.recommendation.ALSModel" + private val className = classOf[ALSModel].getName override def load(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 7ba1a60edaf7..70ccec766c47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -467,14 +467,14 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { // Save model data: intercept, coefficients val data = Data(instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ - private val className = "org.apache.spark.ml.regression.LinearRegressionModel" + private val className = classOf[LinearRegressionModel].getName override def load(path: String): LinearRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 48ce1bb63068..a9a6ff8a783d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -898,7 +898,7 @@ object LogisticRegressionSuite { "regParam" -> 0.01, "elasticNetParam" -> 0.1, "maxIter" -> 2, // intentionally small - "fitIntercept" -> false, + "fitIntercept" -> true, "tol" -> 0.8, 
"standardization" -> false, "threshold" -> 0.6 From 1a93323c5bab18ed7e55bf6f7b13aae88cb9721c Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 18 Nov 2015 23:32:49 -0800 Subject: [PATCH 795/896] [SPARK-11339][SPARKR] Document the list of functions in R base package that are masked by functions with same name in SparkR Added tests for function that are reported as masked, to make sure the base:: or stats:: function can be called. For those we can't call, added them to SparkR programming guide. It would seem to me `table, sample, subset, filter, cov` not working are not actually expected - I investigated/experimented with them but couldn't get them to work. It looks like as they are defined in base or stats they are missing the S3 generic, eg. ``` > methods("transform") [1] transform,ANY-method transform.data.frame [3] transform,DataFrame-method transform.default see '?methods' for accessing help and source code > methods("subset") [1] subset.data.frame subset,DataFrame-method subset.default [4] subset.matrix see '?methods' for accessing help and source code Warning message: In .S3methods(generic.function, class, parent.frame()) : function 'subset' appears not to be S3 generic; found functions that look like S3 methods ``` Any idea? More information on masking: http://www.ats.ucla.edu/stat/r/faq/referencing_objects.htm http://www.sfu.ca/~sweldon/howTo/guide4.pdf This is what the output doc looks like (minus css): ![image](https://cloud.githubusercontent.com/assets/8969467/11229714/2946e5de-8d4d-11e5-94b0-dda9696b6fdd.png) Author: felixcheung Closes #9785 from felixcheung/rmasked. --- R/pkg/R/DataFrame.R | 2 +- R/pkg/R/functions.R | 2 +- R/pkg/R/generics.R | 4 ++-- R/pkg/inst/tests/test_mllib.R | 5 +++++ R/pkg/inst/tests/test_sparkSQL.R | 33 +++++++++++++++++++++++++++- docs/sparkr.md | 37 +++++++++++++++++++++++++++++++- 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 34177e3cdd94..06b0108b1389 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2152,7 +2152,7 @@ setMethod("with", }) #' Returns the column types of a DataFrame. -#' +#' #' @name coltypes #' @title Get column types of a DataFrame #' @family dataframe_funcs diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index ff0f438045c1..25a1f2210149 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2204,7 +2204,7 @@ setMethod("denseRank", #' @export #' @examples \dontrun{lag(df$c)} setMethod("lag", - signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), + signature(x = "characterOrColumn"), function(x, offset, defaultValue = NULL) { col <- if (class(x) == "Column") { x@jc diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0dcd05438222..71004a05ba61 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -539,7 +539,7 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) # @rdname subset # @export -setGeneric("subset", function(x, subset, select, ...) { standardGeneric("subset") }) +setGeneric("subset", function(x, ...) { standardGeneric("subset") }) #' @rdname agg #' @export @@ -790,7 +790,7 @@ setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) #' @rdname lag #' @export -setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") }) +setGeneric("lag", function(x, ...) 
{ standardGeneric("lag") }) #' @rdname last #' @export diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index d497ad8c9daa..e0667e5e22c1 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -31,6 +31,11 @@ test_that("glm and predict", { model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") prediction <- predict(model, test) expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) }) test_that("glm should work with long formula", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d9a94faff7ac..3f4f319fe745 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -433,6 +433,10 @@ test_that("table() returns a new DataFrame", { expect_is(tabledf, "DataFrame") expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") + + # Test base::table is working + #a <- letters[1:3] + #expect_equal(class(table(a, sample(a))), "table") }) test_that("toRDD() returns an RRDD", { @@ -673,6 +677,9 @@ test_that("sample on a DataFrame", { # Also test sample_frac sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) + + # Test base::sample is working + #expect_equal(length(sample(1:12)), 12) }) test_that("select operators", { @@ -753,6 +760,9 @@ test_that("subsetting", { df6 <- subset(df, df$age %in% c(30), c(1,2)) expect_equal(count(df6), 1) expect_equal(columns(df6), c("name", "age")) + + # Test base::subset is working + expect_equal(nrow(subset(airquality, Temp > 80, select = c(Ozone, Temp))), 68) }) test_that("selectExpr() on a DataFrame", { @@ -888,6 +898,9 @@ test_that("column functions", { expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L))) result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L))) + + # Test that stats::lag is working + expect_equal(length(lag(ldeaths, 12)), 72) }) # test_that("column binary mathfunctions", { @@ -1086,7 +1099,7 @@ test_that("group by, agg functions", { gd3_local <- collect(agg(gd3, var(df8$age))) expect_equal(162, gd3_local[gd3_local$name == "Justin",][1, 2]) - # make sure base:: or stats::sd, var are working + # Test stats::sd, stats::var are working expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) @@ -1138,6 +1151,9 @@ test_that("filter() on a DataFrame", { expect_equal(count(filtered5), 1) filtered6 <- where(df, df$age %in% c(19, 30)) expect_equal(count(filtered6), 2) + + # Test stats::filter is working + #expect_true(is.ts(filter(1:100, rep(1, 3)))) }) test_that("join() and merge() on a DataFrame", { @@ -1284,6 +1300,12 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { expect_is(unioned, "DataFrame") expect_equal(count(intersected), 1) expect_equal(first(intersected)$name, "Andy") + + # Test base::rbind is working + expect_equal(length(rbind(1:4, c = 2, a = 10, 10, deparse.level = 0)), 16) + + # Test base::intersect is working + expect_equal(length(intersect(1:20, 3:23)), 18) }) test_that("withColumn() and withColumnRenamed()", { @@ -1365,6 +1387,9 @@ test_that("describe() and summarize() on a DataFrame", { stats2 <- summary(df) expect_equal(collect(stats2)[4, "name"], "Andy") expect_equal(collect(stats2)[5, "age"], "30") + + # Test base::summary is working + 
expect_equal(length(summary(attenu, digits = 4)), 35) }) test_that("dropna() and na.omit() on a DataFrame", { @@ -1448,6 +1473,9 @@ test_that("dropna() and na.omit() on a DataFrame", { expect_identical(expected, actual) actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) expect_identical(expected, actual) + + # Test stats::na.omit is working + expect_equal(nrow(na.omit(data.frame(x = c(0, 10, NA)))), 2) }) test_that("fillna() on a DataFrame", { @@ -1510,6 +1538,9 @@ test_that("cov() and corr() on a DataFrame", { expect_true(abs(result - 1.0) < 1e-12) result <- corr(df, "singles", "doubles", "pearson") expect_true(abs(result - 1.0) < 1e-12) + + # Test stats::cov is working + #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) }) test_that("freqItems() on a DataFrame", { diff --git a/docs/sparkr.md b/docs/sparkr.md index a744b76be746..cfb9b41350f4 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -286,7 +286,7 @@ head(teenagers) # Machine Learning -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. +SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). @@ -351,3 +351,38 @@ summary(model) ##Sepal_Width 0.404655 {% endhighlight %}
+
+# R Function Name Conflicts
+
+When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a
+function is masking another function.
+
+The following functions are masked by the SparkR package:
+
+<table class="table">
+  <tr><th>Masked function</th><th>How to Access</th></tr>
+  <tr>
+    <td><code>cov</code> in <code>package:stats</code></td>
+    <td><code><pre>stats::cov(x, y = NULL, use = "everything",
+           method = c("pearson", "kendall", "spearman"))</pre></code></td>
+  </tr>
+  <tr>
+    <td><code>filter</code> in <code>package:stats</code></td>
+    <td><code><pre>stats::filter(x, filter, method = c("convolution", "recursive"),
+              sides = 2, circular = FALSE, init)</pre></code></td>
+  </tr>
+  <tr>
+    <td><code>sample</code> in <code>package:base</code></td>
+    <td><code>base::sample(x, size, replace = FALSE, prob = NULL)</code></td>
+  </tr>
+  <tr>
+    <td><code>table</code> in <code>package:base</code></td>
+    <td><code><pre>base::table(...,
+            exclude = if (useNA == "no") c(NA, NaN),
+            useNA = c("no", "ifany", "always"),
+            dnn = list.names(...), deparse.level = 1)</pre></code></td>
+  </tr>
+</table>
    + +You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) + From f449992009becc8f7c7f06cda522b9beaa1e263c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Nov 2015 10:48:04 -0800 Subject: [PATCH 796/896] [SPARK-11849][SQL] Analyzer should replace current_date and current_timestamp with literals We currently rely on the optimizer's constant folding to replace current_timestamp and current_date. However, this can still result in different values for different instances of current_timestamp/current_date if the optimizer is not running fast enough. A better solution is to replace these functions in the analyzer in one shot. Author: Reynold Xin Closes #9833 from rxin/SPARK-11849. --- .../sql/catalyst/analysis/Analyzer.scala | 27 ++++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 38 +++++++++++++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f00c451b5981..84781cd57f3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -65,9 +65,8 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, - CTESubstitution :: - WindowsSubstitution :: - Nil : _*), + CTESubstitution, + WindowsSubstitution), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: @@ -84,7 +83,8 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic), + PullOutNondeterministic, + ComputeCurrentTime), Batch("UDF", Once, HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, @@ -1076,7 +1076,7 @@ class Analyzer( override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. - case plan => plan transformExpressionsUp { + case p => p transformExpressionsUp { case udf @ ScalaUDF(func, _, inputs, _) => val parameterTypes = ScalaReflection.getParameterTypes(func) @@ -1162,3 +1162,20 @@ object CleanupAliases extends Rule[LogicalPlan] { } } } + +/** + * Computes the current date and time to make sure we return the same result in a single query. 
+ */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 08586a97411a..e05106995188 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class AnalysisSuite extends AnalysisTest { @@ -218,4 +219,41 @@ class AnalysisSuite extends AnalysisTest { udf4) // checkUDF(udf4, expected4) } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = in.analyze.asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = in.analyze.asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } } From 962878843b611fa6229e3ee67bb22e2a4bc283cd Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 19 Nov 2015 11:02:17 -0800 Subject: [PATCH 797/896] [SPARK-11840][SQL] Restore the 1.5's behavior of planning a single distinct aggregation. The impact of this change is for a query that has a single distinct column and does not have any grouping expression like `SELECT COUNT(DISTINCT a) FROM table` The plan will be changed from ``` AGG-2 (count distinct) Shuffle to a single reducer Partial-AGG-2 (count distinct) AGG-1 (grouping on a) Shuffle by a Partial-AGG-1 (grouping on 1) ``` to the following one (1.5 uses this) ``` AGG-2 AGG-1 (grouping on a) Shuffle to a single reducer Partial-AGG-1(grouping on a) ``` The first plan is more robust. However, to better benchmark the impact of this change, we should use 1.5's plan and use the conf of `spark.sql.specializeSingleDistinctAggPlanning` to control the plan. 
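As a quick way to see the effect described above, one could toggle the configuration named in this message and compare the resulting plans (a hedged sketch, not part of the patch; the table name `events` is made up):

```
// Sketch only: flip the flag discussed above and inspect how a single-distinct
// aggregation is planned. "events" is a hypothetical registered table.
sqlContext.setConf("spark.sql.specializeSingleDistinctAggPlanning", "true")
sqlContext.sql("SELECT COUNT(DISTINCT a) FROM events").explain(true)

sqlContext.setConf("spark.sql.specializeSingleDistinctAggPlanning", "false")
sqlContext.sql("SELECT COUNT(DISTINCT a) FROM events").explain(true)
```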
Author: Yin Huai Closes #9828 from yhuai/distinctRewriter. --- .../sql/catalyst/analysis/DistinctAggregationRewriter.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index c0c960471a61..9c78f6d4cc71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -126,8 +126,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { // When the flag is set to specialize single distinct agg planning, // we will rely on our Aggregation strategy to handle queries with a single - // distinct column and this aggregate operator does have grouping expressions. - distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && a.groupingExpressions.isEmpty) + // distinct column. + distinctAggGroups.size > 1 } else { distinctAggGroups.size >= 1 } From 72d150c271d2b206148fd0917a0def263445121b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 19 Nov 2015 11:57:50 -0800 Subject: [PATCH 798/896] [SPARK-11830][CORE] Make NettyRpcEnv bind to the specified host This PR includes the following change: 1. Bind NettyRpcEnv to the specified host 2. Fix the port information in the log for NettyRpcEnv. 3. Fix the service name of NettyRpcEnv. Author: zsxwing Author: Shixiong Zhu Closes #9821 from zsxwing/SPARK-11830. --- .../src/main/scala/org/apache/spark/SparkEnv.scala | 9 ++++++++- .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 7 +++---- .../org/apache/spark/network/TransportContext.java | 8 +++++++- .../spark/network/server/TransportServer.java | 14 ++++++++++---- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4474a83bedbd..88df27f733f2 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -258,8 +258,15 @@ object SparkEnv extends Logging { if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem } else { + val actorSystemPort = if (port == 0) 0 else rpcEnv.address.port + 1 // Create a ActorSystem for legacy codes - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1 + AkkaUtils.createActorSystem( + actorSystemName + "ActorSystem", + hostname, + actorSystemPort, + conf, + securityManager + )._1 } // Figure out which port Akka actually bound to in case the original port is 0 or occupied. 
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 3e0c49796950..3ce359868039 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -102,7 +102,7 @@ private[netty] class NettyRpcEnv( } else { java.util.Collections.emptyList() } - server = transportContext.createServer(port, bootstraps) + server = transportContext.createServer(host, port, bootstraps) dispatcher.registerRpcEndpoint( RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } @@ -337,10 +337,10 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { if (!config.clientMode) { val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => nettyEnv.startServer(actualPort) - (nettyEnv, actualPort) + (nettyEnv, nettyEnv.address.port) } try { - Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 + Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1 } catch { case NonFatal(e) => nettyEnv.shutdown() @@ -370,7 +370,6 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { * @param conf Spark configuration. * @param endpointAddress The address where the endpoint is listening. * @param nettyEnv The RpcEnv associated with this ref. - * @param local Whether the referenced endpoint lives in the same process. */ private[netty] class NettyRpcEndpointRef( @transient private val conf: SparkConf, diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 1b64b863a9fe..238710d17249 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -94,7 +94,13 @@ public TransportClientFactory createClientFactory() { /** Create a server which will attempt to bind to a specific port. */ public TransportServer createServer(int port, List bootstraps) { - return new TransportServer(this, port, rpcHandler, bootstraps); + return new TransportServer(this, null, port, rpcHandler, bootstraps); + } + + /** Create a server which will attempt to bind to a specific host and port. */ + public TransportServer createServer( + String host, int port, List bootstraps) { + return new TransportServer(this, host, port, rpcHandler, bootstraps); } /** Creates a new server, binding to any available ephemeral port. */ diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index f4fadb1ee3b8..baae235e0220 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -55,9 +55,13 @@ public class TransportServer implements Closeable { private ChannelFuture channelFuture; private int port = -1; - /** Creates a TransportServer that binds to the given port, or to any available if 0. */ + /** + * Creates a TransportServer that binds to the given host and the given port, or to any available + * if 0. If you don't want to bind to any special host, set "hostToBind" to null. 
+ * */ public TransportServer( TransportContext context, + String hostToBind, int portToBind, RpcHandler appRpcHandler, List bootstraps) { @@ -67,7 +71,7 @@ public TransportServer( this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); try { - init(portToBind); + init(hostToBind, portToBind); } catch (RuntimeException e) { JavaUtils.closeQuietly(this); throw e; @@ -81,7 +85,7 @@ public int getPort() { return port; } - private void init(int portToBind) { + private void init(String hostToBind, int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = @@ -120,7 +124,9 @@ protected void initChannel(SocketChannel ch) throws Exception { } }); - channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); + InetSocketAddress address = hostToBind == null ? + new InetSocketAddress(portToBind): new InetSocketAddress(hostToBind, portToBind); + channelFuture = bootstrap.bind(address); channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); From 276a7e130252c0e7aba702ae5570b3c4f424b23b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 19 Nov 2015 12:45:04 -0800 Subject: [PATCH 799/896] [SPARK-11633][SQL] LogicalRDD throws TreeNode Exception : Failed to Copy Node When handling self joins, the implementation did not consider the case insensitivity of HiveContext. It could cause an exception as shown in the JIRA: ``` TreeNodeException: Failed to copy node. ``` The fix is low risk. It avoids unnecessary attribute replacement. It should not affect the existing behavior of self joins. Also added the test case to cover this case. Author: gatorsmile Closes #9762 from gatorsmile/joinMakeCopy. --- .../apache/spark/sql/execution/ExistingRDD.scala | 4 ++++ .../org/apache/spark/sql/DataFrameSuite.scala | 14 ++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 62620ec642c7..623348f6768a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -74,6 +74,10 @@ private[sql] case class LogicalRDD( override def children: Seq[LogicalPlan] = Nil + override protected final def otherCopyArgs: Seq[AnyRef] = { + sqlContext :: Nil + } + override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6399b0165c4c..dd6d06512ff6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1110,6 +1110,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + // This test case is to verify a bug when making a new instance of LogicalRDD. 
+ test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) + val df = sqlContext.createDataFrame( + rdd, + new StructType().add("f1", IntegerType).add("f2", IntegerType), + needsConversion = false).select($"F1", $"f2".as("f2")) + val df1 = df.as("a") + val df2 = df.as("b") + checkAnswer(df1.join(df2, $"a.f2" === $"b.f2"), Row(1, 3, 1, 3) :: Row(2, 1, 2, 1) :: Nil) + } + } + test("SPARK-10656: completely support special chars") { val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.") checkAnswer(df.select(df("*")), Row(1, "a")) From 7d4aba18722727c85893ad8d8f07d4494665dcfc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 19 Nov 2015 12:46:36 -0800 Subject: [PATCH 800/896] [SPARK-11848][SQL] Support EXPLAIN in DataSet APIs When debugging DataSet API, I always need to print the logical and physical plans. I am wondering if we should provide a simple API for EXPLAIN? Author: gatorsmile Closes #9832 from gatorsmile/explainDS. --- .../org/apache/spark/sql/DataFrame.scala | 23 +------------------ .../spark/sql/execution/Queryable.scala | 21 +++++++++++++++++ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3ba4ba18d212..98358127e270 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -308,27 +308,6 @@ class DataFrame private[sql]( def printSchema(): Unit = println(schema.treeString) // scalastyle:on println - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(extended: Boolean): Unit = { - val explain = ExplainCommand(queryExecution.logical, extended = extended) - withPlan(explain).queryExecution.executedPlan.executeCollect().foreach { - // scalastyle:off println - r => println(r.getString(0)) - // scalastyle:on println - } - } - - /** - * Only prints the physical plan to the console for debugging purposes. - * @group basic - * @since 1.3.0 - */ - def explain(): Unit = explain(extended = false) - /** * Returns true if the `collect` and `take` methods can be run locally * (without any Spark executors). 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 9ca383896a09..e86a52c149a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType import scala.util.control.NonFatal @@ -25,6 +26,7 @@ import scala.util.control.NonFatal private[sql] trait Queryable { def schema: StructType def queryExecution: QueryExecution + def sqlContext: SQLContext override def toString: String = { try { @@ -34,4 +36,23 @@ private[sql] trait Queryable { s"Invalid tree; ${e.getMessage}:\n$queryExecution" } } + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @since 1.3.0 + */ + def explain(extended: Boolean): Unit = { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + // scalastyle:off println + r => println(r.getString(0)) + // scalastyle:on println + } + } + + /** + * Only prints the physical plan to the console for debugging purposes. + * @since 1.3.0 + */ + def explain(): Unit = explain(extended = false) } From 47d1c2325caaf9ffe31695b6fff529314b8582f7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Nov 2015 12:54:25 -0800 Subject: [PATCH 801/896] [SPARK-11750][SQL] revert SPARK-11727 and code clean up After some experiment, I found it's not convenient to have separate encoder builders: `FlatEncoder` and `ProductEncoder`. For example, when create encoders for `ScalaUDF`, we have no idea if the type `T` is flat or not. So I revert the splitting change in https://github.com/apache/spark/pull/9693, while still keeping the bug fixes and tests. Author: Wenchen Fan Closes #9726 from cloud-fan/follow. 
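A minimal sketch of the unified builder this change restores (based on the `ExpressionEncoder.apply` shown in the diff below; the element types are arbitrary examples):

```
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// apply() now decides flat vs. struct encoding from T itself:
// non-Product types serialize to a single "value" column, Products map to struct fields.
val intEncoder   = ExpressionEncoder[Int]()            // Int is not a Product  => flat encoder
val tupleEncoder = ExpressionEncoder[(Int, String)]()  // tuples are Products   => struct of two fields
```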
--- .../scala/org/apache/spark/sql/Encoder.scala | 16 +- .../spark/sql/catalyst/ScalaReflection.scala | 354 +++++--------- .../catalyst/encoders/ExpressionEncoder.scala | 19 +- .../sql/catalyst/encoders/FlatEncoder.scala | 50 -- .../catalyst/encoders/ProductEncoder.scala | 452 ------------------ .../sql/catalyst/encoders/RowEncoder.scala | 12 +- .../sql/catalyst/expressions/objects.scala | 7 +- .../sql/catalyst/ScalaReflectionSuite.scala | 68 --- .../encoders/ExpressionEncoderSuite.scala | 218 ++++++++- .../catalyst/encoders/FlatEncoderSuite.scala | 99 ---- .../encoders/ProductEncoderSuite.scala | 156 ------ .../org/apache/spark/sql/GroupedDataset.scala | 4 +- .../org/apache/spark/sql/SQLImplicits.scala | 23 +- .../org/apache/spark/sql/functions.scala | 4 +- 14 files changed, 364 insertions(+), 1118 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index d54f2854fb33..86bb53645903 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -45,14 +45,14 @@ trait Encoder[T] extends Serializable { */ object Encoders { - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) - def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) - def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) - def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + def STRING: Encoder[java.lang.String] = ExpressionEncoder() /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 59ccf356f2c4..33ae700706da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -50,39 +50,29 @@ object ScalaReflection extends ScalaReflection { * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type * system. 
As a result, ObjectType will be returned for things like boxed Integers */ - def dataTypeFor(tpe: `Type`): DataType = tpe match { - case t if t <:< definitions.IntTpe => IntegerType - case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.DoubleTpe => DoubleType - case t if t <:< definitions.FloatTpe => FloatType - case t if t <:< definitions.ShortTpe => ShortType - case t if t <:< definitions.ByteTpe => ByteType - case t if t <:< definitions.BooleanTpe => BooleanType - case t if t <:< localTypeOf[Array[Byte]] => BinaryType - case _ => - val className: String = tpe.erasure.typeSymbol.asClass.fullName - className match { - case "scala.Array" => - val TypeRef(_, _, Seq(arrayType)) = tpe - val cls = arrayType match { - case t if t <:< definitions.IntTpe => classOf[Array[Int]] - case t if t <:< definitions.LongTpe => classOf[Array[Long]] - case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] - case t if t <:< definitions.FloatTpe => classOf[Array[Float]] - case t if t <:< definitions.ShortTpe => classOf[Array[Short]] - case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] - case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] - case other => - // There is probably a better way to do this, but I couldn't find it... - val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls - java.lang.reflect.Array.newInstance(elementType, 1).getClass + def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) - } - ObjectType(cls) - case other => - val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) - ObjectType(clazz) - } + private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + tpe match { + case t if t <:< definitions.IntTpe => IntegerType + case t if t <:< definitions.LongTpe => LongType + case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.FloatTpe => FloatType + case t if t <:< definitions.ShortTpe => ShortType + case t if t <:< definitions.ByteTpe => ByteType + case t if t <:< definitions.BooleanTpe => BooleanType + case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case _ => + val className: String = tpe.erasure.typeSymbol.asClass.fullName + className match { + case "scala.Array" => + val TypeRef(_, _, Seq(elementType)) = tpe + arrayClassFor(elementType) + case other => + val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + ObjectType(clazz) + } + } } /** @@ -90,7 +80,7 @@ object ScalaReflection extends ScalaReflection { * Array[T]. Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ - def arrayClassFor(tpe: `Type`): DataType = { + private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -108,6 +98,15 @@ object ScalaReflection extends ScalaReflection { ObjectType(cls) } + /** + * Returns true if the value of this data type is same between internal and external. + */ + def isNativeType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => true + case _ => false + } + /** * Returns an expression that can be used to construct an object of type `T` given an input * row with a compatible schema. 
Fields of the row will be extracted using UnresolvedAttributes @@ -116,63 +115,33 @@ object ScalaReflection extends ScalaReflection { * * When used on a primitive type, the constructor will instead default to extracting the value * from ordinal 0 (since there are no names to map to). The actual location can be moved by - * calling unbind/bind with a new schema. + * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) + def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None) private def constructorFor( tpe: `Type`, path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = - path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = - path - .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal)) - .getOrElse(BoundReference(ordinal, dataType, false)) + def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path + .map(p => GetInternalRowField(p, ordinal, dataType)) + .getOrElse(BoundReference(ordinal, dataType, false)) - /** Returns the current path or throws an error. */ - def getPath = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) tpe match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => - getPath + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - val boxedType = optType match { - // For primitive types we must manually box the primitive value. 
- case t if t <:< definitions.IntTpe => Some(classOf[java.lang.Integer]) - case t if t <:< definitions.LongTpe => Some(classOf[java.lang.Long]) - case t if t <:< definitions.DoubleTpe => Some(classOf[java.lang.Double]) - case t if t <:< definitions.FloatTpe => Some(classOf[java.lang.Float]) - case t if t <:< definitions.ShortTpe => Some(classOf[java.lang.Short]) - case t if t <:< definitions.ByteTpe => Some(classOf[java.lang.Byte]) - case t if t <:< definitions.BooleanTpe => Some(classOf[java.lang.Boolean]) - case _ => None - } - - boxedType.map { boxedType => - val objectType = ObjectType(boxedType) - WrapOption( - objectType, - NewInstance( - boxedType, - getPath :: Nil, - propagateNull = true, - objectType)) - }.getOrElse { - val className: String = optType.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) - val objectType = ObjectType(cls) - - WrapOption(objectType, constructorFor(optType, path)) - } + WrapOption(constructorFor(optType, path)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] @@ -231,11 +200,11 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.math.BigDecimal] => Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + case t if t <:< localTypeOf[BigDecimal] => + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - val primitiveMethod = elementType match { case t if t <:< definitions.IntTpe => Some("toIntArray") case t if t <:< definitions.LongTpe => Some("toLongArray") @@ -248,57 +217,52 @@ object ScalaReflection extends ScalaReflection { } primitiveMethod.map { method => - Invoke(getPath, method, dataTypeFor(t)) + Invoke(getPath, method, arrayClassFor(elementType)) }.getOrElse { - val returnType = dataTypeFor(t) Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), "array", - returnType) + arrayClassFor(elementType)) } + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val arrayData = + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - - val primitiveMethodKey = keyType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } val keyData = Invoke( MapObjects( p => constructorFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(keyDataType)), - keyDataType), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + 
schemaFor(keyType).dataType), "array", ObjectType(classOf[Array[Any]])) - val primitiveMethodValue = valueType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - val valueData = Invoke( MapObjects( p => constructorFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(valueDataType)), - valueDataType), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + schemaFor(valueType).dataType), "array", ObjectType(classOf[Array[Any]])) @@ -308,40 +272,6 @@ object ScalaReflection extends ScalaReflection { "toScalaMap", keyData :: valueData :: Nil) - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - // Avoid boxing when possible by just wrapping a primitive array. - val primitiveMethod = elementType match { - case _ if nullable => None - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - - val arrayData = primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - Invoke( - MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), - "array", - arrayClassFor(elementType)) - } - - StaticInvoke( - scala.collection.mutable.WrappedArray, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil) - - case t if t <:< localTypeOf[Product] => val formalTypeArgs = t.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = t @@ -361,8 +291,7 @@ object ScalaReflection extends ScalaReflection { } } - val className: String = t.erasure.typeSymbol.asClass.fullName - val cls = Utils.classForName(className) + val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString @@ -370,7 +299,7 @@ object ScalaReflection extends ScalaReflection { val dataType = schemaFor(fieldType).dataType // For tuples, we based grab the inner fields by ordinal instead of name. - if (className startsWith "scala.Tuple") { + if (cls.getName startsWith "scala.Tuple") { constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) } else { constructorFor(fieldType, Some(addToPath(fieldName))) @@ -388,22 +317,19 @@ object ScalaReflection extends ScalaReflection { } else { newInstance } - } } /** Returns expressions for extracting all the fields from the given type. 
*/ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - ScalaReflectionLock.synchronized { - extractorFor(inputObject, typeTag[T].tpe) match { - case s: CreateNamedStruct => s - case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil) - } + extractorFor(inputObject, localTypeOf[T]) match { + case s: CreateNamedStruct => s + case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } } /** Helper for extracting internal fields from a case class. */ - protected def extractorFor( + private def extractorFor( inputObject: Expression, tpe: `Type`): Expression = ScalaReflectionLock.synchronized { if (!inputObject.dataType.isInstanceOf[ObjectType]) { @@ -491,51 +417,36 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (!elementDataType.isInstanceOf[AtomicType]) { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } else { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val elementDataType = dataTypeFor(elementType) - val Schema(dataType, nullable) = schemaFor(elementType) - - if (dataType.isInstanceOf[AtomicType]) { - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = ArrayType(dataType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } + toCatalystArray(inputObject, elementType) case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - val rawMap = inputObject val keys = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + val values = - NewInstance( - classOf[GenericArrayData], - Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil, - dataType = ObjectType(classOf[ArrayData])) + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) NewInstance( classOf[ArrayBasedMapData], - keys :: values :: Nil, + convertedKeys :: convertedValues :: Nil, dataType = MapType(keyDataType, valueDataType, valueNullable)) case t if t <:< localTypeOf[String] => @@ -558,6 +469,7 @@ object ScalaReflection extends ScalaReflection { DateType, "fromJavaDate", inputObject :: Nil) + case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( Decimal, @@ -587,26 +499,24 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) - case t if t <:< definitions.IntTpe => - BoundReference(0, IntegerType, false) - 
case t if t <:< definitions.LongTpe => - BoundReference(0, LongType, false) - case t if t <:< definitions.DoubleTpe => - BoundReference(0, DoubleType, false) - case t if t <:< definitions.FloatTpe => - BoundReference(0, FloatType, false) - case t if t <:< definitions.ShortTpe => - BoundReference(0, ShortType, false) - case t if t <:< definitions.ByteTpe => - BoundReference(0, ByteType, false) - case t if t <:< definitions.BooleanTpe => - BoundReference(0, BooleanType, false) - case other => throw new UnsupportedOperationException(s"Extractor for type $other is not supported") } } } + + private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = schemaFor(elementType) + if (isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, externalDataType) + } + } } /** @@ -635,8 +545,7 @@ trait ScalaReflection { } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: Schema = - ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) } + def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) /** * Return the Scala Type for `T` in the current classloader mirror. @@ -736,39 +645,4 @@ trait ScalaReflection { assert(methods.length == 1) methods.head.getParameterTypes } - - def typeOfObject: PartialFunction[Any, DataType] = { - // The data type can be determined without ambiguity. - case obj: Boolean => BooleanType - case obj: Array[Byte] => BinaryType - case obj: String => StringType - case obj: UTF8String => StringType - case obj: Byte => ByteType - case obj: Short => ShortType - case obj: Int => IntegerType - case obj: Long => LongType - case obj: Float => FloatType - case obj: Double => DoubleType - case obj: java.sql.Date => DateType - case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case obj: Decimal => DecimalType.SYSTEM_DEFAULT - case obj: java.sql.Timestamp => TimestampType - case null => NullType - // For other cases, there is no obvious mapping from the type of the given object to a - // Catalyst data type. A user should provide his/her specific rules - // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of - // objects and then compose the user-defined PartialFunction with this one. - } - - implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { - - /** - * Implicitly added to Sequences of case class objects. Returns a catalyst logical relation - * for the the data in the sequence. 
- */ - def asRelation: LocalRelation = { - val output = attributesFor[A] - LocalRelation.fromProduct(output, data) - } - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 456b59500847..6eeba1442c1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -30,10 +30,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType} +import org.apache.spark.sql.types.{StructField, ObjectType, StructType} /** - * A factory for constructing encoders that convert objects and primitves to and from the + * A factory for constructing encoders that convert objects and primitives to and from the * internal row format using catalyst expressions and code generation. By default, the * expressions used to retrieve values from an input row when producing an object will be created as * follows: @@ -44,20 +44,21 @@ import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType * to the name `value`. */ object ExpressionEncoder { - def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = { + def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val mirror = typeTag[T].mirror val cls = mirror.runtimeClass(typeTag[T].tpe) + val flat = !classOf[Product].isAssignableFrom(cls) - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpression = ScalaReflection.extractorsFor[T](inputObject) - val constructExpression = ScalaReflection.constructorFor[T] + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) + val fromRowExpression = ScalaReflection.constructorFor[T] new ExpressionEncoder[T]( - extractExpression.dataType, + toRowExpression.dataType, flat, - extractExpression.flatten, - constructExpression, + toRowExpression.flatten, + fromRowExpression, ClassTag[T](cls)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala deleted file mode 100644 index 6d307ab13a9f..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag} - -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference} -import org.apache.spark.sql.catalyst.ScalaReflection - -object FlatEncoder { - import ScalaReflection.schemaFor - import ScalaReflection.dataTypeFor - - def apply[T : TypeTag]: ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. - val tpe = typeTag[T].tpe - val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(tpe) - assert(!schemaFor(tpe).dataType.isInstanceOf[StructType]) - - val input = BoundReference(0, dataTypeFor(tpe), nullable = true) - val toRowExpression = CreateNamedStruct( - Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil) - val fromRowExpression = ProductEncoder.constructorFor(tpe) - - new ExpressionEncoder[T]( - toRowExpression.dataType, - flat = true, - toRowExpression.flatten, - fromRowExpression, - ClassTag[T](cls)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala deleted file mode 100644 index 2914c6ee790c..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ /dev/null @@ -1,452 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.encoders - -import org.apache.spark.util.Utils -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} -import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData} - -import scala.reflect.ClassTag - -object ProductEncoder { - import ScalaReflection.universe._ - import ScalaReflection.mirror - import ScalaReflection.localTypeOf - import ScalaReflection.dataTypeFor - import ScalaReflection.Schema - import ScalaReflection.schemaFor - import ScalaReflection.arrayClassFor - - def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. - val tpe = typeTag[T].tpe - val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(tpe) - - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct] - val fromRowExpression = constructorFor(tpe) - - new ExpressionEncoder[T]( - toRowExpression.dataType, - flat = false, - toRowExpression.flatten, - fromRowExpression, - ClassTag[T](cls)) - } - - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - import scala.collection.Map - - def extractorFor( - inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { - if (!inputObject.dataType.isInstanceOf[ObjectType]) { - inputObject - } else { - tpe match { - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - optType match { - // For primitive types we must manually unbox the value of the object. - case t if t <:< definitions.IntTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), - "intValue", - IntegerType) - case t if t <:< definitions.LongTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), - "longValue", - LongType) - case t if t <:< definitions.DoubleTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), - "doubleValue", - DoubleType) - case t if t <:< definitions.FloatTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), - "floatValue", - FloatType) - case t if t <:< definitions.ShortTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), - "shortValue", - ShortType) - case t if t <:< definitions.ByteTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), - "byteValue", - ByteType) - case t if t <:< definitions.BooleanTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), - "booleanValue", - BooleanType) - - // For non-primitives, we can just extract the object from the Option and then recurse. 
- case other => - val className: String = optType.erasure.typeSymbol.asClass.fullName - val classObj = Utils.classForName(className) - val optionObjectType = ObjectType(classObj) - - val unwrapped = UnwrapOption(optionObjectType, inputObject) - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, schemaFor(optType).dataType), - extractorFor(unwrapped, optType)) - } - - case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - CreateNamedStruct(params.head.flatMap { p => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil - }) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keys = - Invoke( - Invoke(inputObject, "keysIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = toCatalystArray(keys, keyType) - - val values = - Invoke( - Invoke(inputObject, "valuesIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = toCatalystArray(values, valueType) - - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - NewInstance( - classOf[ArrayBasedMapData], - convertedKeys :: convertedValues :: Nil, - dataType = MapType(keyDataType, valueDataType, valueNullable)) - - case t if t <:< localTypeOf[String] => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils, - TimestampType, - "fromJavaTimestamp", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils, - DateType, - "fromJavaDate", - inputObject :: Nil) - - case t if t <:< localTypeOf[BigDecimal] => - StaticInvoke( - Decimal, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - StaticInvoke( - Decimal, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.lang.Integer] => - Invoke(inputObject, "intValue", IntegerType) - case t if t <:< localTypeOf[java.lang.Long] => - Invoke(inputObject, "longValue", LongType) - case t if t <:< localTypeOf[java.lang.Double] => - Invoke(inputObject, "doubleValue", DoubleType) - case t if t <:< localTypeOf[java.lang.Float] => - 
Invoke(inputObject, "floatValue", FloatType) - case t if t <:< localTypeOf[java.lang.Short] => - Invoke(inputObject, "shortValue", ShortType) - case t if t <:< localTypeOf[java.lang.Byte] => - Invoke(inputObject, "byteValue", ByteType) - case t if t <:< localTypeOf[java.lang.Boolean] => - Invoke(inputObject, "booleanValue", BooleanType) - - case other => - throw new UnsupportedOperationException(s"Encoder for type $other is not supported") - } - } - } - - private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = schemaFor(elementType) - if (RowEncoder.isNativeType(catalystType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), input, externalDataType) - } - } - - def constructorFor( - tpe: `Type`, - path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized { - - /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String): Expression = path - .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) - - /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetInternalRowField(p, ordinal, dataType)) - .getOrElse(BoundReference(ordinal, dataType, false)) - - /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) - - tpe match { - case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath - - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - WrapOption(null, constructorFor(optType, path)) - - case t if t <:< localTypeOf[java.lang.Integer] => - val boxedType = classOf[java.lang.Integer] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Long] => - val boxedType = classOf[java.lang.Long] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Double] => - val boxedType = classOf[java.lang.Double] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Float] => - val boxedType = classOf[java.lang.Float] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Short] => - val boxedType = classOf[java.lang.Short] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Byte] => - val boxedType = classOf[java.lang.Byte] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.lang.Boolean] => - val boxedType = classOf[java.lang.Boolean] - val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils, - ObjectType(classOf[java.sql.Date]), - "toJavaDate", - getPath :: Nil, - propagateNull = true) - - case t if t <:< 
localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils, - ObjectType(classOf[java.sql.Timestamp]), - "toJavaTimestamp", - getPath :: Nil, - propagateNull = true) - - case t if t <:< localTypeOf[java.lang.String] => - Invoke(getPath, "toString", ObjectType(classOf[String])) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) - - case t if t <:< localTypeOf[BigDecimal] => - Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None - } - - primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p)), - getPath, - schemaFor(elementType).dataType), - "array", - arrayClassFor(elementType)) - } - - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val arrayData = - Invoke( - MapObjects( - p => constructorFor(elementType, Some(p)), - getPath, - schemaFor(elementType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - scala.collection.mutable.WrappedArray, - ObjectType(classOf[Seq[_]]), - "make", - arrayData :: Nil) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keyData = - Invoke( - MapObjects( - p => constructorFor(keyType, Some(p)), - Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), - schemaFor(keyType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - val valueData = - Invoke( - MapObjects( - p => constructorFor(valueType, Some(p)), - Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), - schemaFor(valueType).dataType), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - ArrayBasedMapData, - ObjectType(classOf[Map[_, _]]), - "toScalaMap", - keyData :: valueData :: Nil) - - case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) - - val arguments = params.head.zipWithIndex.map { case (p, i) => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val dataType = schemaFor(fieldType).dataType - - // For tuples, we based grab the inner fields by ordinal instead of name. 
- if (cls.getName startsWith "scala.Tuple") { - constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) - } else { - constructorFor(fieldType, Some(addToPath(fieldName))) - } - } - - val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) - - if (path.nonEmpty) { - expressions.If( - IsNull(getPath), - expressions.Literal.create(null, ObjectType(cls)), - newInstance - ) - } else { - newInstance - } - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 9bb1602494b6..4cda4824acdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -132,17 +133,8 @@ object RowEncoder { CreateStruct(convertedFields) } - /** - * Returns true if the value of this data type is same between internal and external. - */ - def isNativeType(dt: DataType): Boolean = dt match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => true - case _ => false - } - private def externalDataTypeFor(dt: DataType): DataType = dt match { - case _ if isNativeType(dt) => dt + case _ if ScalaReflection.isNativeType(dt) => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index f865a9408ef4..ef7399e0196a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,7 +24,6 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.encoders.ProductEncoder import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.catalyst.InternalRow @@ -300,10 +299,9 @@ case class UnwrapOption( /** * Converts the result of evaluating `child` into an option, checking both the isNull bit and * (in the case of reference types) equality with null. - * @param optionType The datatype to be held inside of the Option. * @param child The expression to evaluate and wrap. 
*/ -case class WrapOption(optionType: DataType, child: Expression) +case class WrapOption(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = ObjectType(classOf[Option[_]]) @@ -316,14 +314,13 @@ case class WrapOption(optionType: DataType, child: Expression) throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val javaType = ctx.javaType(optionType) val inputObject = child.gen(ctx) s""" ${inputObject.code} boolean ${ev.isNull} = false; - scala.Option<$javaType> ${ev.value} = + scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); """ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 4ea410d492b0..c2aace1ef238 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -186,74 +186,6 @@ class ScalaReflectionSuite extends SparkFunSuite { nullable = true)) } - test("get data type of a value") { - // BooleanType - assert(BooleanType === typeOfObject(true)) - assert(BooleanType === typeOfObject(false)) - - // BinaryType - assert(BinaryType === typeOfObject("string".getBytes)) - - // StringType - assert(StringType === typeOfObject("string")) - - // ByteType - assert(ByteType === typeOfObject(127.toByte)) - - // ShortType - assert(ShortType === typeOfObject(32767.toShort)) - - // IntegerType - assert(IntegerType === typeOfObject(2147483647)) - - // LongType - assert(LongType === typeOfObject(9223372036854775807L)) - - // FloatType - assert(FloatType === typeOfObject(3.4028235E38.toFloat)) - - // DoubleType - assert(DoubleType === typeOfObject(1.7976931348623157E308)) - - // DecimalType - assert(DecimalType.SYSTEM_DEFAULT === - typeOfObject(new java.math.BigDecimal("1.7976931348623157E318"))) - - // DateType - assert(DateType === typeOfObject(Date.valueOf("2014-07-25"))) - - // TimestampType - assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00"))) - - // NullType - assert(NullType === typeOfObject(null)) - - def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT - case _ => StringType - } - - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new BigInteger("92233720368547758070"))) - assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( - new java.math.BigDecimal("1.7976931348623157E318"))) - assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) - - def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT - } - - intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) - - def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse { - case c: Seq[_] => ArrayType(typeOfObject3(c.head)) - } - - assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) - assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3)))) - } - test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = InternalRow(1, 1.toLong, 1.toDouble, 1.toFloat, 
1.toShort, 1.toByte, true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index cde0364f3dd9..76459b34a484 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,24 +17,234 @@ package org.apache.spark.sql.catalyst.encoders +import java.sql.{Timestamp, Date} import java.util.Arrays import java.util.concurrent.ConcurrentMap +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe.TypeTag import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.types.ArrayType -abstract class ExpressionEncoderSuite extends SparkFunSuite { - val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() +case class RepeatedStruct(s: Seq[PrimitiveData]) - protected def encodeDecodeTest[T]( +case class NestedArray(a: Array[Array[Int]]) { + override def equals(other: Any): Boolean = other match { + case NestedArray(otherArray) => + java.util.Arrays.deepEquals( + a.asInstanceOf[Array[AnyRef]], + otherArray.asInstanceOf[Array[AnyRef]]) + case _ => false + } +} + +case class BoxedData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean) + +case class RepeatedData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: scala.collection.Map[Int, Long], + mapFieldNull: scala.collection.Map[Int, java.lang.Long], + structField: PrimitiveData) + +case class SpecificCollection(l: List[Int]) + +/** For testing Kryo serialization based encoder. */ +class KryoSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[KryoSerializable].value + } +} + +/** For testing Java serialization based encoder. 
*/ +class JavaSerializable(val value: Int) extends Serializable { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[JavaSerializable].value + } +} + +class ExpressionEncoderSuite extends SparkFunSuite { + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() + + // test flat encoders + encodeDecodeTest(false, "primitive boolean") + encodeDecodeTest(-3.toByte, "primitive byte") + encodeDecodeTest(-3.toShort, "primitive short") + encodeDecodeTest(-3, "primitive int") + encodeDecodeTest(-3L, "primitive long") + encodeDecodeTest(-3.7f, "primitive float") + encodeDecodeTest(-3.7, "primitive double") + + encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean") + encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte") + encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short") + encodeDecodeTest(new java.lang.Integer(-3), "boxed int") + encodeDecodeTest(new java.lang.Long(-3L), "boxed long") + encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float") + encodeDecodeTest(new java.lang.Double(-3.7), "boxed double") + + encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") + // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") + + encodeDecodeTest("hello", "string") + encodeDecodeTest(Date.valueOf("2012-12-23"), "date") + encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp") + encodeDecodeTest(Array[Byte](13, 21, -23), "binary") + + encodeDecodeTest(Seq(31, -123, 4), "seq of int") + encodeDecodeTest(Seq("abc", "xyz"), "seq of string") + encodeDecodeTest(Seq("abc", null, "xyz"), "seq of string with null") + encodeDecodeTest(Seq.empty[Int], "empty seq of int") + encodeDecodeTest(Seq.empty[String], "empty seq of string") + + encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), "seq of seq of int") + encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), + "seq of seq of string") + + encodeDecodeTest(Array(31, -123, 4), "array of int") + encodeDecodeTest(Array("abc", "xyz"), "array of string") + encodeDecodeTest(Array("a", null, "x"), "array of string with null") + encodeDecodeTest(Array.empty[Int], "empty array of int") + encodeDecodeTest(Array.empty[String], "empty array of string") + + encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), "array of array of int") + encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), + "array of array of string") + + encodeDecodeTest(Map(1 -> "a", 2 -> "b"), "map") + encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") + encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + + // Kryo encoders + encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) + encodeDecodeTest(new KryoSerializable(15), "kryo object")( + encoderFor(Encoders.kryo[KryoSerializable])) + + // Java encoders + encodeDecodeTest("hello", "java string")(encoderFor(Encoders.javaSerialization[String])) + encodeDecodeTest(new JavaSerializable(15), "java object")( + encoderFor(Encoders.javaSerialization[JavaSerializable])) + + // test product encoders + private def productTest[T <: Product : ExpressionEncoder](input: T): Unit = { + encodeDecodeTest(input, input.getClass.getSimpleName) + } + + case class InnerClass(i: Int) + productTest(InnerClass(1)) + + productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) + + productTest( + OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(PrimitiveData(1, 1, 1, 1, 1, 1, 
true)))) + + productTest(OptionalData(None, None, None, None, None, None, None, None)) + + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) + + productTest(BoxedData(null, null, null, null, null, null, null)) + + productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) + + productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest( + RepeatedData( + Seq(1, 2), + Seq(new Integer(1), null, new Integer(2)), + Map(1 -> 2L), + Map(1 -> null), + PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) + + productTest(("Seq[(String, String)]", + Seq(("a", "b")))) + productTest(("Seq[(Int, Int)]", + Seq((1, 2)))) + productTest(("Seq[(Long, Long)]", + Seq((1L, 2L)))) + productTest(("Seq[(Float, Float)]", + Seq((1.toFloat, 2.toFloat)))) + productTest(("Seq[(Double, Double)]", + Seq((1.toDouble, 2.toDouble)))) + productTest(("Seq[(Short, Short)]", + Seq((1.toShort, 2.toShort)))) + productTest(("Seq[(Byte, Byte)]", + Seq((1.toByte, 2.toByte)))) + productTest(("Seq[(Boolean, Boolean)]", + Seq((true, false)))) + + productTest(("ArrayBuffer[(String, String)]", + ArrayBuffer(("a", "b")))) + productTest(("ArrayBuffer[(Int, Int)]", + ArrayBuffer((1, 2)))) + productTest(("ArrayBuffer[(Long, Long)]", + ArrayBuffer((1L, 2L)))) + productTest(("ArrayBuffer[(Float, Float)]", + ArrayBuffer((1.toFloat, 2.toFloat)))) + productTest(("ArrayBuffer[(Double, Double)]", + ArrayBuffer((1.toDouble, 2.toDouble)))) + productTest(("ArrayBuffer[(Short, Short)]", + ArrayBuffer((1.toShort, 2.toShort)))) + productTest(("ArrayBuffer[(Byte, Byte)]", + ArrayBuffer((1.toByte, 2.toByte)))) + productTest(("ArrayBuffer[(Boolean, Boolean)]", + ArrayBuffer((true, false)))) + + productTest(("Seq[Seq[(Int, Int)]]", + Seq(Seq((1, 2))))) + + // test for ExpressionEncoder.tuple + encodeDecodeTest( + 1 -> 10L, + "tuple with 2 flat encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[Long])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), + "tuple with 2 product encoders")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[(Int, Long)])) + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), + "tuple with flat encoder and product encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[Int])) + + encodeDecodeTest( + (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), + "tuple with product encoder and flat encoder")( + ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[PrimitiveData])) + + encodeDecodeTest( + (1, (10, 100L)), + "nested tuple encoder") { + val intEnc = ExpressionEncoder[Int] + val longEnc = ExpressionEncoder[Long] + ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) + } + + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() + outers.put(getClass.getName, this) + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, - encoder: ExpressionEncoder[T], testName: String): Unit = { test(s"encode/decode for $testName: $input") { + val encoder = implicitly[ExpressionEncoder[T]] val row = encoder.toRow(input) val schema = encoder.schema.toAttributes val boundEncoder = encoder.resolve(schema, outers).bind(schema) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala deleted file mode 
100644 index 07523d49f426..000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import java.sql.{Date, Timestamp} -import org.apache.spark.sql.Encoders - -class FlatEncoderSuite extends ExpressionEncoderSuite { - encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean") - encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte") - encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short") - encodeDecodeTest(-3, FlatEncoder[Int], "primitive int") - encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long") - encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float") - encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double") - - encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean") - encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte") - encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short") - encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int") - encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long") - encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float") - encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double") - - encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal") - type JDecimal = java.math.BigDecimal - // encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal") - - encodeDecodeTest("hello", FlatEncoder[String], "string") - encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date") - encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp") - encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary") - - encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int") - encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string") - encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null") - encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int") - encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string") - - encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), - FlatEncoder[Seq[Seq[Int]]], "seq of seq of int") - encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), - FlatEncoder[Seq[Seq[String]]], "seq of seq of string") - - encodeDecodeTest(Array(31, -123, 4), FlatEncoder[Array[Int]], 
"array of int") - encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string") - encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string with null") - encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int") - encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string") - - encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), - FlatEncoder[Array[Array[Int]]], "array of array of int") - encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), - FlatEncoder[Array[Array[String]]], "array of array of string") - - encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map") - encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null") - encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), - FlatEncoder[Map[Int, Map[String, Int]]], "map of map") - - // Kryo encoders - encodeDecodeTest("hello", encoderFor(Encoders.kryo[String]), "kryo string") - encodeDecodeTest(new KryoSerializable(15), - encoderFor(Encoders.kryo[KryoSerializable]), "kryo object") - - // Java encoders - encodeDecodeTest("hello", encoderFor(Encoders.javaSerialization[String]), "java string") - encodeDecodeTest(new JavaSerializable(15), - encoderFor(Encoders.javaSerialization[JavaSerializable]), "java object") -} - -/** For testing Kryo serialization based encoder. */ -class KryoSerializable(val value: Int) { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[KryoSerializable].value - } -} - -/** For testing Java serialization based encoder. */ -class JavaSerializable(val value: Int) extends Serializable { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[JavaSerializable].value - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala deleted file mode 100644 index 1798514c5c38..000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} - -case class RepeatedStruct(s: Seq[PrimitiveData]) - -case class NestedArray(a: Array[Array[Int]]) { - override def equals(other: Any): Boolean = other match { - case NestedArray(otherArray) => - java.util.Arrays.deepEquals( - a.asInstanceOf[Array[AnyRef]], - otherArray.asInstanceOf[Array[AnyRef]]) - case _ => false - } -} - -case class BoxedData( - intField: java.lang.Integer, - longField: java.lang.Long, - doubleField: java.lang.Double, - floatField: java.lang.Float, - shortField: java.lang.Short, - byteField: java.lang.Byte, - booleanField: java.lang.Boolean) - -case class RepeatedData( - arrayField: Seq[Int], - arrayFieldContainsNull: Seq[java.lang.Integer], - mapField: scala.collection.Map[Int, Long], - mapFieldNull: scala.collection.Map[Int, java.lang.Long], - structField: PrimitiveData) - -case class SpecificCollection(l: List[Int]) - -class ProductEncoderSuite extends ExpressionEncoderSuite { - outers.put(getClass.getName, this) - - case class InnerClass(i: Int) - productTest(InnerClass(1)) - - productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) - - productTest( - OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), - Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) - - productTest(OptionalData(None, None, None, None, None, None, None, None)) - - productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) - - productTest(BoxedData(null, null, null, null, null, null, null)) - - productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) - - productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - productTest( - RepeatedData( - Seq(1, 2), - Seq(new Integer(1), null, new Integer(2)), - Map(1 -> 2L), - Map(1 -> null), - PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) - - productTest(("Seq[(String, String)]", - Seq(("a", "b")))) - productTest(("Seq[(Int, Int)]", - Seq((1, 2)))) - productTest(("Seq[(Long, Long)]", - Seq((1L, 2L)))) - productTest(("Seq[(Float, Float)]", - Seq((1.toFloat, 2.toFloat)))) - productTest(("Seq[(Double, Double)]", - Seq((1.toDouble, 2.toDouble)))) - productTest(("Seq[(Short, Short)]", - Seq((1.toShort, 2.toShort)))) - productTest(("Seq[(Byte, Byte)]", - Seq((1.toByte, 2.toByte)))) - productTest(("Seq[(Boolean, Boolean)]", - Seq((true, false)))) - - productTest(("ArrayBuffer[(String, String)]", - ArrayBuffer(("a", "b")))) - productTest(("ArrayBuffer[(Int, Int)]", - ArrayBuffer((1, 2)))) - productTest(("ArrayBuffer[(Long, Long)]", - ArrayBuffer((1L, 2L)))) - productTest(("ArrayBuffer[(Float, Float)]", - ArrayBuffer((1.toFloat, 2.toFloat)))) - productTest(("ArrayBuffer[(Double, Double)]", - ArrayBuffer((1.toDouble, 2.toDouble)))) - productTest(("ArrayBuffer[(Short, Short)]", - ArrayBuffer((1.toShort, 2.toShort)))) - productTest(("ArrayBuffer[(Byte, Byte)]", - ArrayBuffer((1.toByte, 2.toByte)))) - productTest(("ArrayBuffer[(Boolean, Boolean)]", - ArrayBuffer((true, false)))) - - productTest(("Seq[Seq[(Int, Int)]]", - Seq(Seq((1, 2))))) - - encodeDecodeTest( - 1 -> 10L, - ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]), - "tuple with 2 flat encoders") - - encodeDecodeTest( - (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), - ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]), - 
"tuple with 2 product encoders") - - encodeDecodeTest( - (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), - ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]), - "tuple with flat encoder and product encoder") - - encodeDecodeTest( - (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), - ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]), - "tuple with product encoder and flat encoder") - - encodeDecodeTest( - (1, (10, 100L)), - { - val intEnc = FlatEncoder[Int] - val longEnc = FlatEncoder[Long] - ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) - }, - "nested tuple encoder") - - private def productTest[T <: Product : TypeTag](input: T): Unit = { - encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 7e5acbe8517d..6de3dd626576 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes} +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, OuterScopes} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -242,7 +242,7 @@ class GroupedDataset[K, T] private[sql]( * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long])) + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) /** * Applies the given function to each cogrouped data. 
For each unique group, the function will diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 8471eea1b7d9..25ffdcde1771 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation - import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -28,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.types.StructField import org.apache.spark.unsafe.types.UTF8String @@ -37,16 +34,16 @@ import org.apache.spark.unsafe.types.UTF8String abstract class SQLImplicits { protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() - implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int] - implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long] - implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double] - implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float] - implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte] - implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short] - implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean] - implicit def newStringEncoder: Encoder[String] = FlatEncoder[String] + implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() /** * Creates a [[Dataset]] from an RDD. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 95158de710ac..b27b1340cce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -26,7 +26,7 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} -import org.apache.spark.sql.catalyst.encoders.FlatEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint @@ -267,7 +267,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(FlatEncoder[Long]) + count(Column(columnName)).as(ExpressionEncoder[Long]) /** * Aggregate function: returns the number of distinct items in a group. From 4700074530d9a398843e13f0ef514be97a237cea Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 19 Nov 2015 13:08:01 -0800 Subject: [PATCH 802/896] [SPARK-11778][SQL] parse table name before it is passed to lookupRelation Fix a bug in DataFrameReader.table (table with schema name such as "db_name.table" doesn't work) Use SqlParser.parseTableIdentifier to parse the table name before lookupRelation. Author: Huaxin Gao Closes #9773 from huaxingao/spark-11778. --- .../scala/org/apache/spark/sql/DataFrameReader.scala | 3 ++- .../spark/sql/hive/HiveDataFrameAnalyticsSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 5872fbded383..dcb3737b70fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -313,7 +313,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def table(tableName: String): DataFrame = { - DataFrame(sqlContext, sqlContext.catalog.lookupRelation(TableIdentifier(tableName))) + DataFrame(sqlContext, + sqlContext.catalog.lookupRelation(SqlParser.parseTableIdentifier(tableName))) } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 9864acf76526..f19a74d4b372 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -34,10 +34,14 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with override def beforeAll() { testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") + hiveContext.sql("create schema usrdb") + hiveContext.sql("create table usrdb.test(c1 int)") } override def afterAll(): Unit = { hiveContext.dropTempTable("mytable") + hiveContext.sql("drop table usrdb.test") + hiveContext.sql("drop schema usrdb") } test("rollup") { @@ -74,4 +78,10 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with sql("select a, b, sum(b) from mytable group by a, 
b with cube").collect() ) } + + // There was a bug in DataFrameFrameReader.table and it has problem for table with schema name, + // Before fix, it throw Exceptionorg.apache.spark.sql.catalyst.analysis.NoSuchTableException + test("table name with schema") { + hiveContext.read.table("usrdb.test") + } } From 599a8c6e2bf7da70b20ef3046f5ce099dfd637f8 Mon Sep 17 00:00:00 2001 From: David Tolpin Date: Thu, 19 Nov 2015 13:57:23 -0800 Subject: [PATCH 803/896] [SPARK-11812][PYSPARK] invFunc=None works properly with python's reduceByKeyAndWindow invFunc is optional and can be None. Instead of invFunc (the parameter) invReduceFunc (a local function) was checked for trueness (that is, not None, in this context). A local function is never None, thus the case of invFunc=None (a common one when inverse reduction is not defined) was treated incorrectly, resulting in loss of data. In addition, the docstring used wrong parameter names, also fixed. Author: David Tolpin Closes #9775 from dtolpin/master. --- python/pyspark/streaming/dstream.py | 6 +++--- python/pyspark/streaming/tests.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 698336cfce18..acec850f02c2 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -524,8 +524,8 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None `invFunc` can be None, then it will reduce all the RDDs in window, could be slower than having `invFunc`. - @param reduceFunc: associative reduce function - @param invReduceFunc: inverse function of `reduceFunc` + @param func: associative reduce function + @param invFunc: inverse function of `reduceFunc` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval @param slideDuration: sliding interval of the window (i.e., the interval after which @@ -556,7 +556,7 @@ def invReduceFunc(t, a, b): if kv[1] is not None else kv[0]) jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) - if invReduceFunc: + if invFunc: jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) else: jinvReduceFunc = None diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 0bcd1f15532b..3403f6d20d78 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -582,6 +582,17 @@ def test_reduce_by_invalid_window(self): self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + def test_reduce_by_key_and_window_with_none_invFunc(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.map(lambda x: (x, 1))\ + .reduceByKeyAndWindow(operator.add, None, 5, 1)\ + .filter(lambda kv: kv[1] > 0).count() + + expected = [[2], [4], [6], [6], [6], [6]] + self._test_func(input, func, expected) + class StreamingContextTests(PySparkStreamingTestCase): From 014c0f7a9dfdb1686fa9aeacaadb2a17a855a943 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Nov 2015 14:48:18 -0800 Subject: [PATCH 804/896] [SPARK-11858][SQL] Move sql.columnar into sql.execution. In addition, tightened visibility of a lot of classes in the columnar package from private[sql] to private[columnar]. Author: Reynold Xin Closes #9842 from rxin/SPARK-11858. 
--- .../spark/sql/execution/CacheManager.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../columnar/ColumnAccessor.scala | 42 +++++++-------- .../columnar/ColumnBuilder.scala | 51 ++++++++++--------- .../columnar/ColumnStats.scala | 34 ++++++------- .../{ => execution}/columnar/ColumnType.scala | 48 ++++++++--------- .../columnar/GenerateColumnAccessor.scala | 4 +- .../columnar/InMemoryColumnarTableScan.scala | 5 +- .../columnar/NullableColumnAccessor.scala | 4 +- .../columnar/NullableColumnBuilder.scala | 4 +- .../CompressibleColumnAccessor.scala | 6 +-- .../CompressibleColumnBuilder.scala | 6 +-- .../compression/CompressionScheme.scala | 16 +++--- .../compression/compressionSchemes.scala | 16 +++--- .../apache/spark/sql/execution/package.scala | 2 + .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../org/apache/spark/sql/QueryTest.scala | 2 +- .../columnar/ColumnStatsSuite.scala | 6 +-- .../columnar/ColumnTypeSuite.scala | 4 +- .../columnar/ColumnarTestUtils.scala | 2 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../NullableColumnAccessorSuite.scala | 4 +- .../columnar/NullableColumnBuilderSuite.scala | 4 +- .../columnar/PartitionBatchPruningSuite.scala | 2 +- .../compression/BooleanBitSetSuite.scala | 6 +-- .../compression/DictionaryEncodingSuite.scala | 6 +-- .../compression/IntegralDeltaSuite.scala | 6 +-- .../compression/RunLengthEncodingSuite.scala | 6 +-- .../TestCompressibleColumnBuilder.scala | 4 +- .../spark/sql/hive/CachedTableSuite.scala | 2 +- 30 files changed, 155 insertions(+), 147 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnAccessor.scala (75%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnBuilder.scala (74%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnStats.scala (88%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnType.scala (93%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/GenerateColumnAccessor.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/InMemoryColumnarTableScan.scala (98%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnAccessor.scala (94%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnBuilder.scala (95%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressibleColumnAccessor.scala (84%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressibleColumnBuilder.scala (94%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/CompressionScheme.scala (83%) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/columnar/compression/compressionSchemes.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnStatsSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnTypeSuite.scala (97%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/ColumnarTestUtils.scala (98%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/InMemoryColumnarQuerySuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/NullableColumnAccessorSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => 
execution}/columnar/NullableColumnBuilderSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/PartitionBatchPruningSuite.scala (99%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/BooleanBitSetSuite.scala (94%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/DictionaryEncodingSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/IntegralDeltaSuite.scala (96%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/RunLengthEncodingSuite.scala (95%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => execution}/columnar/compression/TestCompressibleColumnBuilder.scala (93%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index f85aeb1b0269..293fcfe96e67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3d4ce633c07c..f67c951bc066 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.{Strategy, execution} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala similarity index 75% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 42ec4d3433f1..fee36f602389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -15,12 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow} -import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor +import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor import org.apache.spark.sql.types._ /** @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ * a [[MutableRow]]. In this way, boxing cost can be avoided by leveraging the setter methods * for primitive values provided by [[MutableRow]]. */ -private[sql] trait ColumnAccessor { +private[columnar] trait ColumnAccessor { initialize() protected def initialize() @@ -41,7 +41,7 @@ private[sql] trait ColumnAccessor { protected def underlyingBuffer: ByteBuffer } -private[sql] abstract class BasicColumnAccessor[JvmType]( +private[columnar] abstract class BasicColumnAccessor[JvmType]( protected val buffer: ByteBuffer, protected val columnType: ColumnType[JvmType]) extends ColumnAccessor { @@ -61,65 +61,65 @@ private[sql] abstract class BasicColumnAccessor[JvmType]( protected def underlyingBuffer = buffer } -private[sql] class NullColumnAccessor(buffer: ByteBuffer) +private[columnar] class NullColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[Any](buffer, NULL) with NullableColumnAccessor -private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( +private[columnar] abstract class NativeColumnAccessor[T <: AtomicType]( override protected val buffer: ByteBuffer, override protected val columnType: NativeColumnType[T]) extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor with CompressibleColumnAccessor[T] -private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) +private[columnar] class BooleanColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BOOLEAN) -private[sql] class ByteColumnAccessor(buffer: ByteBuffer) +private[columnar] class ByteColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BYTE) -private[sql] class ShortColumnAccessor(buffer: ByteBuffer) +private[columnar] class ShortColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, SHORT) -private[sql] class IntColumnAccessor(buffer: ByteBuffer) +private[columnar] class IntColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, INT) -private[sql] class LongColumnAccessor(buffer: ByteBuffer) +private[columnar] class LongColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, LONG) -private[sql] class FloatColumnAccessor(buffer: ByteBuffer) +private[columnar] class FloatColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, FLOAT) -private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) +private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, DOUBLE) -private[sql] class StringColumnAccessor(buffer: ByteBuffer) +private[columnar] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) -private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) +private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) with NullableColumnAccessor -private[sql] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) +private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) extends NativeColumnAccessor(buffer, 
COMPACT_DECIMAL(dataType)) -private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) +private[columnar] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType)) with NullableColumnAccessor -private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) +private[columnar] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType) extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType)) with NullableColumnAccessor -private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) +private[columnar] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType)) with NullableColumnAccessor -private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) +private[columnar] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) with NullableColumnAccessor -private[sql] object ColumnAccessor { +private[columnar] object ColumnAccessor { def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val buf = buffer.order(ByteOrder.nativeOrder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala similarity index 74% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index 599f30f2d73b..7e26f19bb744 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -15,16 +15,16 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.ColumnBuilder._ -import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} +import org.apache.spark.sql.execution.columnar.ColumnBuilder._ +import org.apache.spark.sql.execution.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} import org.apache.spark.sql.types._ -private[sql] trait ColumnBuilder { +private[columnar] trait ColumnBuilder { /** * Initializes with an approximate lower bound on the expected number of elements in this column. 
*/ @@ -46,7 +46,7 @@ private[sql] trait ColumnBuilder { def build(): ByteBuffer } -private[sql] class BasicColumnBuilder[JvmType]( +private[columnar] class BasicColumnBuilder[JvmType]( val columnStats: ColumnStats, val columnType: ColumnType[JvmType]) extends ColumnBuilder { @@ -84,17 +84,17 @@ private[sql] class BasicColumnBuilder[JvmType]( } } -private[sql] class NullColumnBuilder +private[columnar] class NullColumnBuilder extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL) with NullableColumnBuilder -private[sql] abstract class ComplexColumnBuilder[JvmType]( +private[columnar] abstract class ComplexColumnBuilder[JvmType]( columnStats: ColumnStats, columnType: ColumnType[JvmType]) extends BasicColumnBuilder[JvmType](columnStats, columnType) with NullableColumnBuilder -private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( +private[columnar] abstract class NativeColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) extends BasicColumnBuilder[T#InternalType](columnStats, columnType) @@ -102,40 +102,45 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( with AllCompressionSchemes with CompressibleColumnBuilder[T] -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) +private[columnar] +class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) +private[columnar] +class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) -private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) +private[columnar] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) -private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) +private[columnar] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) -private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) +private[columnar] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) +private[columnar] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) -private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) +private[columnar] +class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +private[columnar] +class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) +private[columnar] +class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) -private[sql] class CompactDecimalColumnBuilder(dataType: DecimalType) +private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType) extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType)) -private[sql] class DecimalColumnBuilder(dataType: DecimalType) +private[columnar] class DecimalColumnBuilder(dataType: DecimalType) extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType)) -private[sql] class StructColumnBuilder(dataType: StructType) 
+private[columnar] class StructColumnBuilder(dataType: StructType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType)) -private[sql] class ArrayColumnBuilder(dataType: ArrayType) +private[columnar] class ArrayColumnBuilder(dataType: ArrayType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType)) -private[sql] class MapColumnBuilder(dataType: MapType) +private[columnar] class MapColumnBuilder(dataType: MapType) extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType)) -private[sql] object ColumnBuilder { +private[columnar] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 128 * 1024 val MAX_BATCH_SIZE_IN_BYTE = 4 * 1024 * 1024L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala similarity index 88% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 91a05650585c..c52ee9ffd6d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { +private[columnar] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)() val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() @@ -32,7 +32,7 @@ private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes) } -private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { +private[columnar] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { val (forAttribute, schema) = { val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a)) (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _)) @@ -45,10 +45,10 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]` * brings significant performance penalty. */ -private[sql] sealed trait ColumnStats extends Serializable { +private[columnar] sealed trait ColumnStats extends Serializable { protected var count = 0 protected var nullCount = 0 - private[sql] var sizeInBytes = 0L + private[columnar] var sizeInBytes = 0L /** * Gathers statistics information from `row(ordinal)`. @@ -72,14 +72,14 @@ private[sql] sealed trait ColumnStats extends Serializable { /** * A no-op ColumnStats only used for testing purposes. 
*/ -private[sql] class NoopColumnStats extends ColumnStats { +private[columnar] class NoopColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) override def collectedStatistics: GenericInternalRow = new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) } -private[sql] class BooleanColumnStats extends ColumnStats { +private[columnar] class BooleanColumnStats extends ColumnStats { protected var upper = false protected var lower = true @@ -97,7 +97,7 @@ private[sql] class BooleanColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ByteColumnStats extends ColumnStats { +private[columnar] class ByteColumnStats extends ColumnStats { protected var upper = Byte.MinValue protected var lower = Byte.MaxValue @@ -115,7 +115,7 @@ private[sql] class ByteColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ShortColumnStats extends ColumnStats { +private[columnar] class ShortColumnStats extends ColumnStats { protected var upper = Short.MinValue protected var lower = Short.MaxValue @@ -133,7 +133,7 @@ private[sql] class ShortColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class IntColumnStats extends ColumnStats { +private[columnar] class IntColumnStats extends ColumnStats { protected var upper = Int.MinValue protected var lower = Int.MaxValue @@ -151,7 +151,7 @@ private[sql] class IntColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class LongColumnStats extends ColumnStats { +private[columnar] class LongColumnStats extends ColumnStats { protected var upper = Long.MinValue protected var lower = Long.MaxValue @@ -169,7 +169,7 @@ private[sql] class LongColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class FloatColumnStats extends ColumnStats { +private[columnar] class FloatColumnStats extends ColumnStats { protected var upper = Float.MinValue protected var lower = Float.MaxValue @@ -187,7 +187,7 @@ private[sql] class FloatColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class DoubleColumnStats extends ColumnStats { +private[columnar] class DoubleColumnStats extends ColumnStats { protected var upper = Double.MinValue protected var lower = Double.MaxValue @@ -205,7 +205,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class StringColumnStats extends ColumnStats { +private[columnar] class StringColumnStats extends ColumnStats { protected var upper: UTF8String = null protected var lower: UTF8String = null @@ -223,7 +223,7 @@ private[sql] class StringColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class BinaryColumnStats extends ColumnStats { +private[columnar] class BinaryColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { @@ -235,7 +235,7 @@ private[sql] class BinaryColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](null, 
null, nullCount, count, sizeInBytes)) } -private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { +private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { def this(dt: DecimalType) = this(dt.precision, dt.scale) protected var upper: Decimal = null @@ -256,7 +256,7 @@ private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends Column new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class ObjectColumnStats(dataType: DataType) extends ColumnStats { +private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats { val columnType = ColumnType(dataType) override def gatherStats(row: InternalRow, ordinal: Int): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 68e509eb5047..c9f2329db4b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer @@ -41,7 +41,7 @@ import org.apache.spark.unsafe.types.UTF8String * * WARNNING: This only works with HeapByteBuffer */ -object ByteBufferHelper { +private[columnar] object ByteBufferHelper { def getInt(buffer: ByteBuffer): Int = { val pos = buffer.position() buffer.position(pos + 4) @@ -73,7 +73,7 @@ object ByteBufferHelper { * * @tparam JvmType Underlying Java type to represent the elements. */ -private[sql] sealed abstract class ColumnType[JvmType] { +private[columnar] sealed abstract class ColumnType[JvmType] { // The catalyst data type of this column. 
def dataType: DataType @@ -142,7 +142,7 @@ private[sql] sealed abstract class ColumnType[JvmType] { override def toString: String = getClass.getSimpleName.stripSuffix("$") } -private[sql] object NULL extends ColumnType[Any] { +private[columnar] object NULL extends ColumnType[Any] { override def dataType: DataType = NullType override def defaultSize: Int = 0 @@ -152,7 +152,7 @@ private[sql] object NULL extends ColumnType[Any] { override def getField(row: InternalRow, ordinal: Int): Any = null } -private[sql] abstract class NativeColumnType[T <: AtomicType]( +private[columnar] abstract class NativeColumnType[T <: AtomicType]( val dataType: T, val defaultSize: Int) extends ColumnType[T#InternalType] { @@ -163,7 +163,7 @@ private[sql] abstract class NativeColumnType[T <: AtomicType]( def scalaTag: TypeTag[dataType.InternalType] = dataType.tag } -private[sql] object INT extends NativeColumnType(IntegerType, 4) { +private[columnar] object INT extends NativeColumnType(IntegerType, 4) { override def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } @@ -192,7 +192,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) { } } -private[sql] object LONG extends NativeColumnType(LongType, 8) { +private[columnar] object LONG extends NativeColumnType(LongType, 8) { override def append(v: Long, buffer: ByteBuffer): Unit = { buffer.putLong(v) } @@ -220,7 +220,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 8) { } } -private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { +private[columnar] object FLOAT extends NativeColumnType(FloatType, 4) { override def append(v: Float, buffer: ByteBuffer): Unit = { buffer.putFloat(v) } @@ -248,7 +248,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 4) { } } -private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { +private[columnar] object DOUBLE extends NativeColumnType(DoubleType, 8) { override def append(v: Double, buffer: ByteBuffer): Unit = { buffer.putDouble(v) } @@ -276,7 +276,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) { } } -private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { +private[columnar] object BOOLEAN extends NativeColumnType(BooleanType, 1) { override def append(v: Boolean, buffer: ByteBuffer): Unit = { buffer.put(if (v) 1: Byte else 0: Byte) } @@ -302,7 +302,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 1) { } } -private[sql] object BYTE extends NativeColumnType(ByteType, 1) { +private[columnar] object BYTE extends NativeColumnType(ByteType, 1) { override def append(v: Byte, buffer: ByteBuffer): Unit = { buffer.put(v) } @@ -330,7 +330,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 1) { } } -private[sql] object SHORT extends NativeColumnType(ShortType, 2) { +private[columnar] object SHORT extends NativeColumnType(ShortType, 2) { override def append(v: Short, buffer: ByteBuffer): Unit = { buffer.putShort(v) } @@ -362,7 +362,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) { * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper * objects. 
*/ -private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { +private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { // copy the bytes from ByteBuffer to UnsafeRow override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { @@ -387,7 +387,7 @@ private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { } } -private[sql] object STRING +private[columnar] object STRING extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] { override def actualSize(row: InternalRow, ordinal: Int): Int = { @@ -425,7 +425,7 @@ private[sql] object STRING override def clone(v: UTF8String): UTF8String = v.clone() } -private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) +private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int) extends NativeColumnType(DecimalType(precision, scale), 8) { override def extract(buffer: ByteBuffer): Decimal = { @@ -467,13 +467,13 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) } } -private[sql] object COMPACT_DECIMAL { +private[columnar] object COMPACT_DECIMAL { def apply(dt: DecimalType): COMPACT_DECIMAL = { COMPACT_DECIMAL(dt.precision, dt.scale) } } -private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) +private[columnar] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] { def serialize(value: JvmType): Array[Byte] @@ -492,7 +492,7 @@ private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: } } -private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { +private[columnar] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def dataType: DataType = BinaryType @@ -512,7 +512,7 @@ private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { def deserialize(bytes: Array[Byte]): Array[Byte] = bytes } -private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) +private[columnar] case class LARGE_DECIMAL(precision: Int, scale: Int) extends ByteArrayColumnType[Decimal](12) { override val dataType: DataType = DecimalType(precision, scale) @@ -539,13 +539,13 @@ private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) } } -private[sql] object LARGE_DECIMAL { +private[columnar] object LARGE_DECIMAL { def apply(dt: DecimalType): LARGE_DECIMAL = { LARGE_DECIMAL(dt.precision, dt.scale) } } -private[sql] case class STRUCT(dataType: StructType) +private[columnar] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] { private val numOfFields: Int = dataType.fields.size @@ -586,7 +586,7 @@ private[sql] case class STRUCT(dataType: StructType) override def clone(v: UnsafeRow): UnsafeRow = v.copy() } -private[sql] case class ARRAY(dataType: ArrayType) +private[columnar] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] { override def defaultSize: Int = 16 @@ -625,7 +625,7 @@ private[sql] case class ARRAY(dataType: ArrayType) override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() } -private[sql] case class MAP(dataType: MapType) +private[columnar] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] { override def defaultSize: Int = 32 @@ -663,7 +663,7 @@ private[sql] case class MAP(dataType: MapType) override def clone(v: UnsafeMapData): UnsafeMapData = v.copy() } 
-private[sql] object ColumnType { +private[columnar] object ColumnType { def apply(dataType: DataType): ColumnType[_] = { dataType match { case NullType => NULL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index ff9393b465b7..eaafc96e4d2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow @@ -121,7 +121,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; - import org.apache.spark.sql.columnar.MutableUnsafeRow; + import org.apache.spark.sql.execution.columnar.MutableUnsafeRow; public SpecificColumnarIterator generate($exprType[] expr) { return new SpecificColumnarIterator(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index ae77298e6da2..ce701fb3a7f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import scala.collection.mutable.ArrayBuffer @@ -50,7 +50,8 @@ private[sql] object InMemoryRelation { * @param buffers The buffers for serialized columns * @param stats The stat of columns */ -private[sql] case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) +private[columnar] +case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) private[sql] case class InMemoryRelation( output: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 7eaecfe047c3..8d99546924de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.catalyst.expressions.MutableRow -private[sql] trait NullableColumnAccessor extends ColumnAccessor { +private[columnar] trait NullableColumnAccessor extends ColumnAccessor { private var nullsBuffer: ByteBuffer = _ private var nullCount: Int = _ private var seenNulls: Int = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala index 76cfddf1cd01..3a1931bfb5c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilder.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteBuffer, ByteOrder} @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow * +---+-----+---------+ * }}} */ -private[sql] trait NullableColumnBuilder extends ColumnBuilder { +private[columnar] trait NullableColumnBuilder extends ColumnBuilder { protected var nulls: ByteBuffer = _ protected var nullCount: Int = _ private var pos: Int = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala similarity index 84% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index cb205defbb1a..6579b5068e65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} +import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} import org.apache.spark.sql.types.AtomicType -private[sql] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { +private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { this: NativeColumnAccessor[T] => private var decoder: Decoder[T] = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala index 161021ff9615..b0e216feb559 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} +import org.apache.spark.sql.execution.columnar.{ColumnBuilder, NativeColumnBuilder} import org.apache.spark.sql.types.AtomicType /** @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.AtomicType * header body * }}} */ -private[sql] trait CompressibleColumnBuilder[T <: AtomicType] +private[columnar] trait CompressibleColumnBuilder[T <: AtomicType] extends ColumnBuilder with Logging { this: NativeColumnBuilder[T] with WithCompressionSchemes => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala similarity index 83% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index 9322b772fd89..920381f9c63d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -15,15 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} +import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} import org.apache.spark.sql.types.AtomicType -private[sql] trait Encoder[T <: AtomicType] { +private[columnar] trait Encoder[T <: AtomicType] { def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {} def compressedSize: Int @@ -37,13 +37,13 @@ private[sql] trait Encoder[T <: AtomicType] { def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: AtomicType] { +private[columnar] trait Decoder[T <: AtomicType] { def next(row: MutableRow, ordinal: Int): Unit def hasNext: Boolean } -private[sql] trait CompressionScheme { +private[columnar] trait CompressionScheme { def typeId: Int def supports(columnType: ColumnType[_]): Boolean @@ -53,15 +53,15 @@ private[sql] trait CompressionScheme { def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } -private[sql] trait WithCompressionSchemes { +private[columnar] trait WithCompressionSchemes { def schemes: Seq[CompressionScheme] } -private[sql] trait AllCompressionSchemes extends WithCompressionSchemes { +private[columnar] trait AllCompressionSchemes extends WithCompressionSchemes { override val schemes: Seq[CompressionScheme] = CompressionScheme.all } -private[sql] object CompressionScheme { +private[columnar] object CompressionScheme { val all: Seq[CompressionScheme] = Seq(PassThrough, RunLengthEncoding, DictionaryEncoding, BooleanBitSet, IntDelta, LongDelta) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index 41c9a284e3e4..941f03b745a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer @@ -23,11 +23,11 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types._ -private[sql] case object PassThrough extends CompressionScheme { +private[columnar] case object PassThrough extends CompressionScheme { override val typeId = 0 override def supports(columnType: ColumnType[_]): Boolean = true @@ -64,7 +64,7 @@ private[sql] case object PassThrough extends CompressionScheme { } } -private[sql] case object RunLengthEncoding extends CompressionScheme { +private[columnar] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { @@ -172,7 +172,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } -private[sql] case object DictionaryEncoding extends CompressionScheme { +private[columnar] case object DictionaryEncoding extends CompressionScheme { override val typeId = 2 // 32K unique values allowed @@ -281,7 +281,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } -private[sql] case object BooleanBitSet extends CompressionScheme { +private[columnar] case object BooleanBitSet extends CompressionScheme { override val typeId = 3 val BITS_PER_LONG = 64 @@ -371,7 +371,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { } } -private[sql] case object IntDelta extends CompressionScheme { +private[columnar] case object IntDelta extends CompressionScheme { override def typeId: Int = 4 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) @@ -451,7 +451,7 @@ private[sql] case object IntDelta extends CompressionScheme { } } -private[sql] case object LongDelta extends CompressionScheme { +private[columnar] case object LongDelta extends CompressionScheme { override def typeId: Int = 5 override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala index 28fa231e722d..c912734bba9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -19,5 +19,7 @@ package org.apache.spark.sql /** * The physical execution component of Spark SQL. Note that this is a private package. + * All classes in catalyst are considered an internal API to Spark SQL and are subject + * to change between minor releases. 
*/ package object execution diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index bce94dafad75..d86df4cfb9b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -27,7 +27,7 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} import org.apache.spark.storage.{StorageLevel, RDDBlockId} @@ -280,7 +280,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext sql("CACHE TABLE testData") sqlContext.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => - val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum + val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index b5417b195f39..6ea1fe4ccfd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.InMemoryRelation abstract class QueryTest extends PlanTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index 89a664001bdd..b2d04f7c5a6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow @@ -50,7 +50,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) @@ -86,7 +86,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ val columnStats = new DecimalColumnStats(15, 10) val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 63bc39bfa030..34dd96929e6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types._ import org.apache.spark.{Logging, SparkFunSuite} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala similarity index 98% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index a5882f7870e3..9cae65ef6f5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import scala.collection.immutable.HashSet import scala.util.Random diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 6265e40a0a07..25afed25c897 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.sql.{Date, Timestamp} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala index aa1605fee8c7..35dc9a276cef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import java.nio.ByteBuffer @@ -38,7 +38,7 @@ object TestNullableColumnAccessor { } class NullableColumnAccessorSuite extends SparkFunSuite { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala index 91404577832a..93be3e16a5ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters @@ -36,7 +36,7 @@ object TestNullableColumnBuilder { } class NullableColumnBuilderSuite extends SparkFunSuite { - import org.apache.spark.sql.columnar.ColumnarTestUtils._ + import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ Seq( BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 6b7401464f46..d762f7bfe914 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar +package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala similarity index 94% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index 9a2948c59ba4..ccbddef0fad3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.columnar.{BOOLEAN, NoopColumnStats} +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index acfab6586c0d..830ca0294e1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -15,14 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 2111e9fbe62c..988a577a7b4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala similarity index 95% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index 67ec08f594a4..ce3affba55c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -15,12 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala index 5268dfe0aa03..5e078f251375 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.columnar.compression +package org.apache.spark.sql.execution.columnar.compression -import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.types.AtomicType class TestCompressibleColumnBuilder[T <: AtomicType]( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 5c2fc7d82ffb..99478e82d419 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.execution.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} From 90d384dcbc1d1a3466cf8bae570a26f23012c102 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 19 Nov 2015 14:49:25 -0800 Subject: [PATCH 805/896] [SPARK-11831][CORE][TESTS] Use port 0 to avoid port conflicts in tests Use port 0 to fix port-contention-related flakiness Author: Shixiong Zhu Closes #9841 from zsxwing/SPARK-11831. 
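For reference, the reason port 0 removes the flakiness is that it defers port selection to the operating system: every bind receives a fresh ephemeral port, so concurrent test JVMs can never race for the same hard-coded value. A minimal standalone sketch of the idea (plain java.net.ServerSocket rather than Spark's RpcEnv; the object name is illustrative only):

```
import java.net.ServerSocket

object EphemeralPortExample {
  def main(args: Array[String]): Unit = {
    // Port 0 asks the kernel for any free ephemeral port, so two suites
    // running at the same time cannot collide on a fixed port number.
    val socket = new ServerSocket(0)
    try {
      println(s"OS assigned port ${socket.getLocalPort}")
    } finally {
      socket.close()
    }
  }
}
```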
--- .../org/apache/spark/rpc/RpcEnvSuite.scala | 24 +++++++++---------- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 834e4743df86..2f55006420ce 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -39,7 +39,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val conf = new SparkConf() - env = createRpcEnv(conf, "local", 12345) + env = createRpcEnv(conf, "local", 0) } override def afterAll(): Unit = { @@ -76,7 +76,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") try { @@ -130,7 +130,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") try { @@ -158,7 +158,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val shortProp = "spark.rpc.short.timeout" conf.set("spark.rpc.retry.wait", "0") conf.set("spark.rpc.numRetries", "1") - val anotherEnv = createRpcEnv(conf, "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { @@ -417,7 +417,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") try { @@ -457,7 +457,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-remotely-error") @@ -497,7 +497,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "network-events") @@ -543,7 +543,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = 
true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-unserializable-error") @@ -571,8 +571,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) try { @volatile var message: String = null @@ -602,8 +602,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) + val localEnv = createRpcEnv(conf, "authentication-local", 0) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) try { localEnv.setupEndpoint("ask-authentication", new RpcEndpoint { diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index 6478ab51c4da..7aac02775e1b 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -40,7 +40,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { }) val conf = new SparkConf() val newRpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf), false)) + RpcEnvConfig(conf, "test", "localhost", 0, new SecurityManager(conf), false)) try { val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") assert(s"akka.tcp://local@${env.address}/user/test_endpoint" === @@ -59,7 +59,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { val conf = SSLSampleConfigs.sparkSSLConfig() val securityManager = new SecurityManager(conf) val rpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, securityManager, false)) + RpcEnvConfig(conf, "test", "localhost", 0, securityManager, false)) try { val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) From 3bd77b213a9cd177c3ea3c61d37e5098e55f75a5 Mon Sep 17 00:00:00 2001 From: Srinivasa Reddy Vundela Date: Thu, 19 Nov 2015 14:51:40 -0800 Subject: [PATCH 806/896] =?UTF-8?q?[SPARK-11799][CORE]=20Make=20it=20expli?= =?UTF-8?q?cit=20in=20executor=20logs=20that=20uncaught=20e=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …xceptions are thrown during executor shutdown This commit will make sure that when uncaught exceptions are prepended with [Container in shutdown] when JVM is shutting down. Author: Srinivasa Reddy Vundela Closes #9809 from vundela/master_11799. 
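A rough, self-contained sketch of the intended behaviour (the actual handler in the diff below consults Spark's internal ShutdownHookManager and logs via logError; the AtomicBoolean flag, shutdown hook, and stderr output here are stand-ins used only to keep the example runnable on its own):

```
import java.util.concurrent.atomic.AtomicBoolean

object ShutdownAwareHandler extends Thread.UncaughtExceptionHandler {
  // Stand-in for ShutdownHookManager.inShutdown(): flip a flag from a JVM shutdown hook.
  private val shuttingDown = new AtomicBoolean(false)
  Runtime.getRuntime.addShutdownHook(new Thread(new Runnable {
    override def run(): Unit = shuttingDown.set(true)
  }))

  override def uncaughtException(thread: Thread, exception: Throwable): Unit = {
    // Prefix the message so users reading executor logs can tell shutdown-time
    // exceptions apart from failures during normal operation.
    val prefix = if (shuttingDown.get()) "[Container in shutdown] " else ""
    Console.err.println(s"${prefix}Uncaught exception in thread $thread")
    exception.printStackTrace()
  }
}
```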
--- .../apache/spark/util/SparkUncaughtExceptionHandler.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala index 724818724733..5e322557e964 100644 --- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala +++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala @@ -29,7 +29,11 @@ private[spark] object SparkUncaughtExceptionHandler override def uncaughtException(thread: Thread, exception: Throwable) { try { - logError("Uncaught exception in thread " + thread, exception) + // Make it explicit that uncaught exceptions are thrown when container is shutting down. + // It will help users when they analyze the executor logs + val inShutdownMsg = if (ShutdownHookManager.inShutdown()) "[Container in shutdown] " else "" + val errMsg = "Uncaught exception in thread " + logError(inShutdownMsg + errMsg + thread, exception) // We may have been called from a shutdown hook. If so, we must not call System.exit(). // (If we do, we will deadlock.) From f7135ed7194d4f936f6f58e14f02b1ed93f68ad1 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Nov 2015 14:53:58 -0800 Subject: [PATCH 807/896] [SPARK-11828][CORE] Register DAGScheduler metrics source after app id is known. Author: Marcelo Vanzin Closes #9820 from vanzin/SPARK-11828. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ab374cb71286..af4456c05b0a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -581,6 +581,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Post init _taskScheduler.postStartHook() + _env.metricsSystem.registerSource(_dagScheduler.metricsSource) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) _executorAllocationManager.foreach { e => _env.metricsSystem.registerSource(e.executorAllocationManagerSource) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4a9518fff4e7..ae725b467d8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -130,7 +130,7 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) - private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() @@ -1580,8 +1580,6 @@ class DAGScheduler( taskScheduler.stop() } - // Start the event thread and register the metrics source at the end of the constructor - env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } From 01403aa97b6aaab9b86ae806b5ea9e82690a741f Mon Sep 17 00:00:00 2001 From: hushan Date: Thu, 19 Nov 2015 14:56:00 -0800 Subject: [PATCH 808/896] [SPARK-11746][CORE] Use cache-aware method dependencies a small change Author: hushan Closes #9691 from 
suyanNone/unify-getDependency. --- .../main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index d6a37e8cc5da..0c6ddda52cee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -65,7 +65,7 @@ class PartitionPruningRDD[T: ClassTag]( } override protected def getPartitions: Array[Partition] = - getDependencies.head.asInstanceOf[PruneDependency[T]].partitions + dependencies.head.asInstanceOf[PruneDependency[T]].partitions } From 37cff1b1a79cad11277612cb9bc8bc2365cf5ff2 Mon Sep 17 00:00:00 2001 From: Andrew Ray Date: Thu, 19 Nov 2015 15:11:30 -0800 Subject: [PATCH 809/896] [SPARK-11275][SQL] Incorrect results when using rollup/cube Fixes bug with grouping sets (including cube/rollup) where aggregates that included grouping expressions would return the wrong (null) result. Also simplifies the analyzer rule a bit and leaves column pruning to the optimizer. Added multiple unit tests to DataFrameAggregateSuite and verified it passes hive compatibility suite: ``` build/sbt -Phive -Dspark.hive.whitelist='groupby.*_grouping.*' 'test-only org.apache.spark.sql.hive.execution.HiveCompatibilitySuite' ``` This is an alternative to pr https://github.com/apache/spark/pull/9419 but I think its better as it simplifies the analyzer rule instead of adding another special case to it. Author: Andrew Ray Closes #9815 from aray/groupingset-agg-fix. --- .../sql/catalyst/analysis/Analyzer.scala | 58 +++++++---------- .../plans/logical/basicOperators.scala | 4 ++ .../spark/sql/DataFrameAggregateSuite.scala | 62 +++++++++++++++++++ 3 files changed, 90 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 84781cd57f3d..47962ebe6ef8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -213,45 +213,35 @@ class Analyzer( GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) case x: GroupingSets => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() - // We will insert another Projection if the GROUP BY keys contains the - // non-attribute expressions. And the top operators can references those - // expressions by its alias. - // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==> - // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a - - // find all of the non-attribute expressions in the GROUP BY keys - val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]() - - // The pair of (the original GROUP BY key, associated attribute) - val groupByExprPairs = x.groupByExprs.map(_ match { - case e: NamedExpression => (e, e.toAttribute) - case other => { - val alias = Alias(other, other.toString)() - nonAttributeGroupByExpressions += alias // add the non-attributes expression alias - (other, alias.toAttribute) - } - }) - - // substitute the non-attribute expressions for aggregations. 
- val aggregation = x.aggregations.map(expr => expr.transformDown { - case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e) - }.asInstanceOf[NamedExpression]) - // substitute the group by expressions. - val newGroupByExprs = groupByExprPairs.map(_._2) + // Expand works by setting grouping expressions to null as determined by the bitmasks. To + // prevent these null values from being used in an aggregate instead of the original value + // we need to create new aliases for all group by expressions that will only be used for + // the intended purpose. + val groupByAliases: Seq[Alias] = x.groupByExprs.map { + case e: NamedExpression => Alias(e, e.name)() + case other => Alias(other, other.toString)() + } - val child = if (nonAttributeGroupByExpressions.length > 0) { - // insert additional projection if contains the - // non-attribute expressions in the GROUP BY keys - Project(x.child.output ++ nonAttributeGroupByExpressions, x.child) - } else { - x.child + val aggregations: Seq[NamedExpression] = x.aggregations.map { + // If an expression is an aggregate (contains a AggregateExpression) then we dont change + // it so that the aggregation is computed on the unmodified value of its argument + // expressions. + case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr + // If not then its a grouping expression and we need to use the modified (with nulls from + // Expand) value of the expression. + case expr => expr.transformDown { + case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e) + }.asInstanceOf[NamedExpression] } + val child = Project(x.child.output ++ groupByAliases, x.child) + val groupByAttributes = groupByAliases.map(_.toAttribute) + Aggregate( - newGroupByExprs :+ VirtualColumn.groupingIdAttribute, - aggregation, - Expand(x.bitmasks, newGroupByExprs, gid, child)) + groupByAttributes :+ VirtualColumn.groupingIdAttribute, + aggregations, + Expand(x.bitmasks, groupByAttributes, gid, child)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 45630a591d34..0c444482c5e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -323,6 +323,10 @@ trait GroupingAnalytics extends UnaryNode { override def output: Seq[Attribute] = aggregations.map(_.toAttribute) + // Needs to be unresolved before its translated to Aggregate + Expand because output attributes + // will change in analysis. 
+ override lazy val resolved: Boolean = false + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 71adf2148a40..9c42f65bb6f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -60,6 +60,68 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("rollup") { + checkAnswer( + courseSales.rollup("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("cube") { + checkAnswer( + courseSales.cube("course", "year").sum("earnings"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("rollup overlapping columns") { + checkAnswer( + testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.rollup("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, null, 9) :: Nil + ) + } + + test("cube overlapping columns") { + checkAnswer( + testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), + Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) + :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) + :: Row(null, 1, 3) :: Row(null, 2, 0) + :: Row(null, null, 3) :: Nil + ) + + checkAnswer( + testData2.cube("a", "b").agg(sum("b")), + Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) + :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) + :: Row(null, 1, 3) :: Row(null, 2, 6) + :: Row(null, null, 9) :: Nil + ) + } + test("spark.sql.retainGroupColumns config") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), From 880128f37e1bc0b9d98d1786670be62a06c648f2 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Nov 2015 16:49:18 -0800 Subject: [PATCH 810/896] [SPARK-4134][CORE] Lower severity of some executor loss logs. Don't log ERROR messages when executors are explicitly killed or when the exit reason is not yet known. Author: Marcelo Vanzin Closes #9780 from vanzin/SPARK-11789. 
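The gist of the change is that the log level is now derived from the loss reason instead of always being ERROR. A simplified sketch of that mapping (the trait and case objects below are stand-ins for Spark's ExecutorLossReason hierarchy, not the real classes):

```
sealed trait LossReason
case object ReasonPending extends LossReason            // real reason not yet known
case object KilledByDriver extends LossReason           // explicit kill requested by the driver
final case class Exited(message: String) extends LossReason

object ExecutorLossLogging {
  def logLevelFor(reason: LossReason): String = reason match {
    case ReasonPending  => "DEBUG"   // nothing actionable to report yet
    case KilledByDriver => "INFO"    // expected event, not an application failure
    case Exited(_)      => "ERROR"   // genuine, unexpected executor loss
  }
}
```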
--- .../spark/scheduler/ExecutorLossReason.scala | 2 + .../spark/scheduler/TaskSchedulerImpl.scala | 44 ++++++++++++------- .../spark/scheduler/TaskSetManager.scala | 1 + .../CoarseGrainedSchedulerBackend.scala | 18 +++++--- .../spark/deploy/yarn/YarnAllocator.scala | 4 +- 5 files changed, 45 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 47a5cbff4930..7e1197d74280 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -40,6 +40,8 @@ private[spark] object ExecutorExited { } } +private[spark] object ExecutorKilled extends ExecutorLossReason("Executor killed by driver.") + /** * A loss reason that means we don't yet know why the executor exited. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bf0419db1f75..bdf19f9f277d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -470,25 +470,25 @@ private[spark] class TaskSchedulerImpl( synchronized { if (executorIdToTaskCount.contains(executorId)) { val hostPort = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) + logExecutorLoss(executorId, hostPort, reason) removeExecutor(executorId, reason) failedExecutor = Some(executorId) } else { - executorIdToHost.get(executorId) match { - case Some(_) => - // If the host mapping still exists, it means we don't know the loss reason for the - // executor. So call removeExecutor() to update tasks running on that executor when - // the real loss reason is finally known. - logError(s"Actual reason for lost executor $executorId: ${reason.message}") - removeExecutor(executorId, reason) - - case None => - // We may get multiple executorLost() calls with different loss reasons. For example, - // one may be triggered by a dropped connection from the slave while another may be a - // report of executor termination from Mesos. We produce log messages for both so we - // eventually report the termination reason. - logError("Lost an executor " + executorId + " (already removed): " + reason) - } + executorIdToHost.get(executorId) match { + case Some(hostPort) => + // If the host mapping still exists, it means we don't know the loss reason for the + // executor. So call removeExecutor() to update tasks running on that executor when + // the real loss reason is finally known. + logExecutorLoss(executorId, hostPort, reason) + removeExecutor(executorId, reason) + + case None => + // We may get multiple executorLost() calls with different loss reasons. For example, + // one may be triggered by a dropped connection from the slave while another may be a + // report of executor termination from Mesos. We produce log messages for both so we + // eventually report the termination reason. 
+ logError(s"Lost an executor $executorId (already removed): $reason") + } } } // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock @@ -498,6 +498,18 @@ private[spark] class TaskSchedulerImpl( } } + private def logExecutorLoss( + executorId: String, + hostPort: String, + reason: ExecutorLossReason): Unit = reason match { + case LossReasonPending => + logDebug(s"Executor $executorId on $hostPort lost, but reason not yet known.") + case ExecutorKilled => + logInfo(s"Executor $executorId on $hostPort killed by driver.") + case _ => + logError(s"Lost executor $executorId on $hostPort: $reason") + } + /** * Remove an executor from all our data structures and mark it as lost. If the executor's loss * reason is not yet known, do not yet remove its association with its host nor update the status diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 114468c48c44..a02f3017cb6e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -800,6 +800,7 @@ private[spark] class TaskSetManager( for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { val exitCausedByApp: Boolean = reason match { case exited: ExecutorExited => exited.exitCausedByApp + case ExecutorKilled => false case _ => true } handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 6f0c910c009a..505c161141c8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -64,8 +64,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private val listenerBus = scheduler.sc.listenerBus - // Executors we have requested the cluster manager to kill that have not died yet - private val executorsPendingToRemove = new HashSet[String] + // Executors we have requested the cluster manager to kill that have not died yet; maps + // the executor ID to whether it was explicitly killed by the driver (and thus shouldn't + // be considered an app-related failure). 
+ private val executorsPendingToRemove = new HashMap[String, Boolean] // A map to store hostname with its possible task number running on it protected var hostToLocalTaskCount: Map[String, Int] = Map.empty @@ -250,15 +252,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case Some(executorInfo) => // This must be synchronized because variables mutated // in this block are read when requesting executors - CoarseGrainedSchedulerBackend.this.synchronized { + val killed = CoarseGrainedSchedulerBackend.this.synchronized { addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId - executorsPendingToRemove -= executorId executorsPendingLossReason -= executorId + executorsPendingToRemove.remove(executorId).getOrElse(false) } totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) - scheduler.executorLost(executorId, reason) + scheduler.executorLost(executorId, if (killed) ExecutorKilled else reason) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) case None => logInfo(s"Asked to remove non-existent executor $executorId") @@ -459,6 +461,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. * + * When asking the executor to be replaced, the executor loss is considered a failure, and + * killed tasks that are running on the executor will count towards the failure limits. If no + * replacement is being requested, then the tasks will not count towards the limit. + * * @param executorIds identifiers of executors to kill * @param replace whether to replace the killed executors with new ones * @param force whether to force kill busy executors @@ -479,7 +485,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val executorsToKill = knownExecutors .filter { id => !executorsPendingToRemove.contains(id) } .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsPendingToRemove ++= executorsToKill + executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } // If we do not wish to replace the executors we kill, sync the target number of executors // with the cluster manager to avoid allocating new ones. When computing the new target, diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 7e39c3ea56af..73cd9031f025 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -481,7 +481,7 @@ private[yarn] class YarnAllocator( (true, memLimitExceededLogMessage( completedContainer.getDiagnostics, PMEM_EXCEEDED_PATTERN)) - case unknown => + case _ => numExecutorsFailed += 1 (true, "Container marked as failed: " + containerId + onHostStr + ". 
Exit status: " + completedContainer.getExitStatus + @@ -493,7 +493,7 @@ private[yarn] class YarnAllocator( } else { logInfo(containerExitReason) } - ExecutorExited(0, exitCausedByApp, containerExitReason) + ExecutorExited(exitStatus, exitCausedByApp, containerExitReason) } else { // If we have already released this container, then it must mean // that the driver has explicitly requested it to be killed From b2cecb80ece59a1c086d4ae7aeebef445a4e7299 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 19 Nov 2015 16:50:08 -0800 Subject: [PATCH 811/896] [SPARK-11845][STREAMING][TEST] Added unit test to verify TrackStateRDD is correctly checkpointed To make sure that all lineage is correctly truncated for TrackStateRDD when checkpointed. Author: Tathagata Das Closes #9831 from tdas/SPARK-11845. --- .../org/apache/spark/CheckpointSuite.scala | 411 +++++++++--------- .../streaming/rdd/TrackStateRDDSuite.scala | 60 ++- 2 files changed, 267 insertions(+), 204 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 119e5fc28e41..ab23326c6c25 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,17 +21,223 @@ import java.io.File import scala.reflect.ClassTag +import org.apache.spark.CheckpointSuite._ import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils +trait RDDCheckpointTester { self: SparkFunSuite => + + protected val partitioner = new HashPartitioner(2) + + private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() + + /** Implementations of this trait must implement this method */ + protected def sparkContext: SparkContext + + /** + * Test checkpointing of the RDD generated by the given operation. It tests whether the + * serialized size of the RDD is reduce after checkpointing or not. This function should be called + * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). 
+ * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDD[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.headOption.orNull + val rddType = operatedRDD.getClass.getSimpleName + val numPartitions = operatedRDD.partitions.length + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + val partitionsBeforeCheckpoint = operatedRDD.partitions + + // Find serialized sizes before and after the checkpoint + logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + checkpoint(operatedRDD, reliableCheckpoint) + val result = collectFunc(operatedRDD) + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the checkpoint file has been created + if (reliableCheckpoint) { + assert( + collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) + } + + // Test whether dependencies have been changed from its earlier parent RDD + assert(operatedRDD.dependencies.head.rdd != parentRDD) + + // Test whether the partitions have been changed from its earlier partitions + assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) + + // Test whether the partitions have been changed to the new Hadoop partitions + assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) + + // Test whether the number of partitions is same as before + assert(operatedRDD.partitions.length === numPartitions) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the RDD has reduced. + logInfo("Size of " + rddType + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + /** + * Test whether checkpointing of the parent of the generated RDD also + * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent + * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, + * the generated RDD will remember the partitions and therefore potentially the whole lineage. + * This function should be called only those RDD whose partitions refer to parent RDD's + * partitions (i.e., do not call it on simple RDD like MappedRDD). 
+ * + * @param op an operation to run on the RDD + * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints + * @param collectFunc a function for collecting the values in the RDD, in case there are + * non-comparable types like arrays that we want to convert to something + * that supports == + */ + protected def testRDDPartitions[U: ClassTag]( + op: (RDD[Int]) => RDD[U], + reliableCheckpoint: Boolean, + collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { + // Generate the final RDD using given RDD operation + val baseRDD = generateFatRDD() + val operatedRDD = op(baseRDD) + val parentRDDs = operatedRDD.dependencies.map(_.rdd) + val rddType = operatedRDD.getClass.getSimpleName + + // Force initialization of all the data structures in RDDs + // Without this, serializing the RDD will give a wrong estimate of the size of the RDD + initializeRdd(operatedRDD) + + // Find serialized sizes before and after the checkpoint + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + // checkpoint the parent RDD, not the generated one + parentRDDs.foreach { rdd => + checkpoint(rdd, reliableCheckpoint) + } + val result = collectFunc(operatedRDD) // force checkpointing + operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables + val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) + + // Test whether the data in the checkpointed RDD is same as original + assert(collectFunc(operatedRDD) === result) + + // Test whether serialized size of the partitions has reduced + logInfo("Size of partitions of " + rddType + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") + assert( + partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, + "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + + " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" + ) + } + + /** + * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks + * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. + */ + private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + val rddSize = Utils.serialize(rdd).size + val rddCpDataSize = Utils.serialize(rdd.checkpointData).size + val rddPartitionSize = Utils.serialize(rdd.partitions).size + val rddDependenciesSize = Utils.serialize(rdd.dependencies).size + + // Print detailed size, helps in debugging + logInfo("Serialized sizes of " + rdd + + ": RDD = " + rddSize + + ", RDD checkpoint data = " + rddCpDataSize + + ", RDD partitions = " + rddPartitionSize + + ", RDD dependencies = " + rddDependenciesSize + ) + // this makes sure that serializing the RDD's checkpoint data does not + // serialize the whole RDD as well + assert( + rddSize > rddCpDataSize, + "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + + "whole RDD with checkpoint data (" + rddSize + ")" + ) + (rddSize - rddCpDataSize, rddPartitionSize) + } + + /** + * Serialize and deserialize an object. 
This is useful to verify the objects + * contents after deserialization (e.g., the contents of an RDD split after + * it is sent to a slave along with a task) + */ + protected def serializeDeserialize[T](obj: T): T = { + val bytes = Utils.serialize(obj) + Utils.deserialize[T](bytes) + } + + /** + * Recursively force the initialization of the all members of an RDD and it parents. + */ + private def initializeRdd(rdd: RDD[_]): Unit = { + rdd.partitions // forces the initialization of the partitions + rdd.dependencies.map(_.rdd).foreach(initializeRdd) + } + + /** Checkpoint the RDD either locally or reliably. */ + protected def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { + if (reliableCheckpoint) { + rdd.checkpoint() + } else { + rdd.localCheckpoint() + } + } + + /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ + protected def runTest(name: String)(body: Boolean => Unit): Unit = { + test(name + " [reliable checkpoint]")(body(true)) + test(name + " [local checkpoint]")(body(false)) + } + + /** + * Generate an RDD such that both the RDD and its partitions have large size. + */ + protected def generateFatRDD(): RDD[Int] = { + new FatRDD(sparkContext.makeRDD(1 to 100, 4)).map(x => x) + } + + /** + * Generate an pair RDD (with partitioner) such that both the RDD and its partitions + * have large size. + */ + protected def generateFatPairRDD(): RDD[(Int, Int)] = { + new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) + } +} + /** * Test suite for end-to-end checkpointing functionality. * This tests both reliable checkpoints and local checkpoints. */ -class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging { +class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalSparkContext { private var checkpointDir: File = _ - private val partitioner = new HashPartitioner(2) override def beforeEach(): Unit = { super.beforeEach() @@ -46,6 +252,8 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging Utils.deleteRecursively(checkpointDir) } + override def sparkContext: SparkContext = sc + runTest("basic checkpointing") { reliableCheckpoint: Boolean => val parCollection = sc.makeRDD(1 to 4) val flatMappedRDD = parCollection.flatMap(x => 1 to x) @@ -250,204 +458,6 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging assert(rdd.isCheckpointedAndMaterialized === true) assert(rdd.partitions.size === 0) } - - // Utility test methods - - /** Checkpoint the RDD either locally or reliably. */ - private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = { - if (reliableCheckpoint) { - rdd.checkpoint() - } else { - rdd.localCheckpoint() - } - } - - /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */ - private def runTest(name: String)(body: Boolean => Unit): Unit = { - test(name + " [reliable checkpoint]")(body(true)) - test(name + " [local checkpoint]")(body(false)) - } - - private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect() - - /** - * Test checkpointing of the RDD generated by the given operation. It tests whether the - * serialized size of the RDD is reduce after checkpointing or not. This function should be called - * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.). 
- * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDD[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.headOption.orNull - val rddType = operatedRDD.getClass.getSimpleName - val numPartitions = operatedRDD.partitions.length - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - val partitionsBeforeCheckpoint = operatedRDD.partitions - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - checkpoint(operatedRDD, reliableCheckpoint) - val result = collectFunc(operatedRDD) - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the checkpoint file has been created - if (reliableCheckpoint) { - assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result) - } - - // Test whether dependencies have been changed from its earlier parent RDD - assert(operatedRDD.dependencies.head.rdd != parentRDD) - - // Test whether the partitions have been changed from its earlier partitions - assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList) - - // Test whether the partitions have been changed to the new Hadoop partitions - assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) - - // Test whether the number of partitions is same as before - assert(operatedRDD.partitions.length === numPartitions) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the RDD has reduced. - logInfo("Size of " + rddType + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing " + - " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } - - /** - * Test whether checkpointing of the parent of the generated RDD also - * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent - * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, - * the generated RDD will remember the partitions and therefore potentially the whole lineage. - * This function should be called only those RDD whose partitions refer to parent RDD's - * partitions (i.e., do not call it on simple RDD like MappedRDD). 
- * - * @param op an operation to run on the RDD - * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints - * @param collectFunc a function for collecting the values in the RDD, in case there are - * non-comparable types like arrays that we want to convert to something that supports == - */ - private def testRDDPartitions[U: ClassTag]( - op: (RDD[Int]) => RDD[U], - reliableCheckpoint: Boolean, - collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = { - // Generate the final RDD using given RDD operation - val baseRDD = generateFatRDD() - val operatedRDD = op(baseRDD) - val parentRDDs = operatedRDD.dependencies.map(_.rdd) - val rddType = operatedRDD.getClass.getSimpleName - - // Force initialization of all the data structures in RDDs - // Without this, serializing the RDD will give a wrong estimate of the size of the RDD - initializeRdd(operatedRDD) - - // Find serialized sizes before and after the checkpoint - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - // checkpoint the parent RDD, not the generated one - parentRDDs.foreach { rdd => - checkpoint(rdd, reliableCheckpoint) - } - val result = collectFunc(operatedRDD) // force checkpointing - operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables - val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString) - - // Test whether the data in the checkpointed RDD is same as original - assert(collectFunc(operatedRDD) === result) - - // Test whether serialized size of the partitions has reduced - logInfo("Size of partitions of " + rddType + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]") - assert( - partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" + - " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]" - ) - } - - /** - * Generate an RDD such that both the RDD and its partitions have large size. - */ - private def generateFatRDD(): RDD[Int] = { - new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x) - } - - /** - * Generate an pair RDD (with partitioner) such that both the RDD and its partitions - * have large size. - */ - private def generateFatPairRDD(): RDD[(Int, Int)] = { - new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) - } - - /** - * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks - * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. 
- */ - private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - val rddSize = Utils.serialize(rdd).size - val rddCpDataSize = Utils.serialize(rdd.checkpointData).size - val rddPartitionSize = Utils.serialize(rdd.partitions).size - val rddDependenciesSize = Utils.serialize(rdd.dependencies).size - - // Print detailed size, helps in debugging - logInfo("Serialized sizes of " + rdd + - ": RDD = " + rddSize + - ", RDD checkpoint data = " + rddCpDataSize + - ", RDD partitions = " + rddPartitionSize + - ", RDD dependencies = " + rddDependenciesSize - ) - // this makes sure that serializing the RDD's checkpoint data does not - // serialize the whole RDD as well - assert( - rddSize > rddCpDataSize, - "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " + - "whole RDD with checkpoint data (" + rddSize + ")" - ) - (rddSize - rddCpDataSize, rddPartitionSize) - } - - /** - * Serialize and deserialize an object. This is useful to verify the objects - * contents after deserialization (e.g., the contents of an RDD split after - * it is sent to a slave along with a task) - */ - private def serializeDeserialize[T](obj: T): T = { - val bytes = Utils.serialize(obj) - Utils.deserialize[T](bytes) - } - - /** - * Recursively force the initialization of the all members of an RDD and it parents. - */ - private def initializeRdd(rdd: RDD[_]): Unit = { - rdd.partitions // forces the - rdd.dependencies.map(_.rdd).foreach(initializeRdd) - } - } /** RDD partition that has large serialized size. */ @@ -494,5 +504,4 @@ object CheckpointSuite { part ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } - } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index 19ef5a14f8ab..0feb3af1abb0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -17,31 +17,40 @@ package org.apache.spark.streaming.rdd +import java.io.File + import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll +import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.streaming.util.OpenHashMapBasedStateMap -import org.apache.spark.streaming.{Time, State} -import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.streaming.{State, Time} +import org.apache.spark.util.Utils -class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { +class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll { private var sc: SparkContext = null + private var checkpointDir: File = _ override def beforeAll(): Unit = { sc = new SparkContext( new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite")) + checkpointDir = Utils.createTempDir() + sc.setCheckpointDir(checkpointDir.toString) } override def afterAll(): Unit = { if (sc != null) { sc.stop() } + Utils.deleteRecursively(checkpointDir) } + override def sparkContext: SparkContext = sc + test("creation from pair RDD") { val data = Seq((1, "1"), (2, "2"), (3, "3")) val partitioner = new HashPartitioner(10) @@ -278,6 +287,51 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { rdd7, Seq(("k3", 2)), Set()) } + test("checkpointing") { + /** + * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs - + * the data 
RDD and the parent TrackStateRDD. + */ + def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]]) + : Set[(List[(Int, Int, Long)], List[Int])] = { + rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) } + .collect.toSet + } + + /** Generate TrackStateRDD with data RDD having a long lineage */ + def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int]) + : TrackStateRDD[Int, Int, Int, Int] = { + TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0)) + } + + testRDD( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _) + + /** Generate TrackStateRDD with parent state RDD having a long lineage */ + def makeStateRDDWithLongLineageParenttateRDD( + longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = { + + // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage + val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD) + + // Create a new TrackStateRDD, with the lineage lineage TrackStateRDD as the parent + new TrackStateRDD[Int, Int, Int, Int]( + stateRDDWithLongLineage, + stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner), + (time: Time, key: Int, value: Option[Int], state: State[Int]) => None, + Time(10), + None + ) + } + + testRDD( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + testRDDPartitions( + makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) + } + /** Assert whether the `trackStateByKey` operation generates expected results */ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T], From ee21407747fb00db2f26d1119446ccbb20c19232 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 19 Nov 2015 17:14:10 -0800 Subject: [PATCH 812/896] [SPARK-11864][SQL] Improve performance of max/min This PR has the following optimization: 1) The greatest/least already does the null-check, so the `If` and `IsNull` are not necessary. 2) In greatest/least, it should initialize the result using the first child (removing one block). 3) For primitive types, the generated greater expression is too complicated (`a > b ? 1 : (a < b) ? -1 : 0) > 0`), should be as simple as `a > b` Combine these optimization, this could improve the performance of `ss_max` query by 30%. Author: Davies Liu Closes #9846 from davies/improve_max. 
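In Scala terms, the generated update for Greatest now behaves like the hand-written loop below (a hypothetical helper, with Option[Long] standing in for a nullable primitive column): seed the result with the first child and overwrite it only when a later non-null value is strictly greater, so no extra If/IsNull wrapper is needed and primitives compare with a plain `>`.

```
object GreatestSketch {
  def greatest(values: Seq[Option[Long]]): Option[Long] = {
    // Initialize from the first child (optimization 2) instead of a null default.
    var result: Option[Long] = values.head
    for (v <- values.tail) {
      // Null handling is folded into the update itself (optimization 1), and
      // primitives use a plain `>` rather than a three-way compare (optimization 3).
      if (v.isDefined && (result.isEmpty || v.get > result.get)) {
        result = v
      }
    }
    result
  }
}
```

For example, greatest(Seq(None, Some(3L), Some(7L), Some(5L))) returns Some(7L), and an all-None input returns None, matching SQL null semantics for max.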
--- .../catalyst/expressions/aggregate/Max.scala | 5 +-- .../catalyst/expressions/aggregate/Min.scala | 5 +-- .../expressions/codegen/CodeGenerator.scala | 12 ++++++ .../expressions/conditionalExpressions.scala | 38 +++++++++++-------- .../expressions/nullExpressions.scala | 10 +++-- 5 files changed, 45 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 61cae44cd0f5..906003188d4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -46,13 +46,12 @@ case class Max(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) + /* max = */ Greatest(Seq(max, child)) ) override lazy val mergeExpressions: Seq[Expression] = { - val greatest = Greatest(Seq(max.left, max.right)) Seq( - /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) + /* max = */ Greatest(Seq(max.left, max.right)) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 242456d9e2e1..39f7afbd081c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -47,13 +47,12 @@ case class Min(child: Expression) extends DeclarativeAggregate { ) override lazy val updateExpressions: Seq[Expression] = Seq( - /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) + /* min = */ Least(Seq(min, child)) ) override lazy val mergeExpressions: Seq[Expression] = { - val least = Least(Seq(min.left, min.right)) Seq( - /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) + /* min = */ Least(Seq(min.left, min.right)) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1718cfbd3533..1b7260cdfe51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -329,6 +329,18 @@ class CodeGenContext { throw new IllegalArgumentException("cannot generate compare code for un-comparable type") } + /** + * Generates code for greater of two expressions. + * + * @param dataType data type of the expressions + * @param c1 name of the variable of expression 1's output + * @param c2 name of the variable of expression 2's output + */ + def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { + case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" + case _ => s"(${genComp(dataType, c1, c2)}) > 0" + } + /** * List of java data types that have special accessors and setters in [[InternalRow]]. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 0d4af43978ea..694a2a7c54a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -348,19 +348,22 @@ case class Least(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: GeneratedExpressionCode): String = s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).value, ev.value)} < 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, ev.value, eval.value)})) { ${ev.isNull} = false; - ${ev.value} = ${evalChildren(i).value}; + ${ev.value} = ${eval.value}; } """ s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } @@ -403,19 +406,22 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evalChildren = children.map(_.gen(ctx)) - def updateEval(i: Int): String = + val first = evalChildren(0) + val rest = evalChildren.drop(1) + def updateEval(eval: GeneratedExpressionCode): String = s""" - if (!${evalChildren(i).isNull} && (${ev.isNull} || - ${ctx.genComp(dataType, evalChildren(i).value, ev.value)} > 0)) { + ${eval.code} + if (!${eval.isNull} && (${ev.isNull} || + ${ctx.genGreater(dataType, eval.value, ev.value)})) { ${ev.isNull} = false; - ${ev.value} = ${evalChildren(i).value}; + ${ev.value} = ${eval.value}; } """ s""" - ${evalChildren.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${children.indices.map(updateEval).mkString("\n")} + ${first.code} + boolean ${ev.isNull} = ${first.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; + ${rest.map(updateEval).mkString("\n")} """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 94deafb75b69..df4747d4e6f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -62,11 +62,15 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val first = children(0) + val rest = children.drop(1) + val firstEval = first.gen(ctx) s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${firstEval.code} + boolean ${ev.isNull} = ${firstEval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = 
${firstEval.value}; """ + - children.map { e => + rest.map { e => val eval = e.gen(ctx) s""" if (${ev.isNull}) { From 7ee7d5a3c4ff77d2cee2afce36ff41f6302e6315 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 19 Nov 2015 19:46:10 -0800 Subject: [PATCH 813/896] [SPARK-11544][SQL][TEST-HADOOP1.0] sqlContext doesn't use PathFilter Apply the user supplied pathfilter while retrieving the files from fs. Author: Dilip Biswal Closes #9830 from dilipbiswal/spark-11544. --- .../apache/spark/sql/sources/interfaces.scala | 25 ++++++++--- .../datasources/json/JsonSuite.scala | 41 ++++++++++++++++++- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b3d3bdf50df6..f9465157c936 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,7 +21,8 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{JobConf, FileInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} @@ -447,9 +448,15 @@ abstract class HadoopFsRelation private[sql]( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - logInfo(s"Listing $qualified on driver") - Try(fs.listStatus(qualified)).getOrElse(Array.empty) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(qualified, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(qualified)).getOrElse(Array.empty) + } }.filterNot { status => val name = status.getPath.getName name.toLowerCase == "_temporary" || name.startsWith(".") @@ -847,8 +854,16 @@ private[sql] object HadoopFsRelation extends Logging { if (name == "_temporary" || name.startsWith(".")) { Array.empty } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6042b1178aff..ba7718c86463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,19 +19,27 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} +import 
scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory -import org.apache.spark.rdd.RDD +import org.apache.commons.io.FileUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, PathFilter} import org.scalactic.Tolerance._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { import testImplicits._ @@ -1390,4 +1398,33 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("SPARK-11544 test pathfilter") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(2) + df.write.json(path + "/p=1") + df.write.json(path + "/p=2") + assert(sqlContext.read.json(path).count() === 4) + + val clonedConf = new Configuration(hadoopConfiguration) + try { + // Setting it twice as the name of the propery has changed between hadoop versions. + hadoopConfiguration.setClass( + "mapred.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + hadoopConfiguration.setClass( + "mapreduce.input.pathFilter.class", + classOf[TestFileFilter], + classOf[PathFilter]) + assert(sqlContext.read.json(path).count() === 2) + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } } From 4114ce20fbe820f111e55e891ae3889b0e6e0006 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 19 Nov 2015 22:01:02 -0800 Subject: [PATCH 814/896] [SPARK-11846] Add save/load for AFTSurvivalRegression and IsotonicRegression https://issues.apache.org/jira/browse/SPARK-11846 mengxr Author: Xusen Yin Closes #9836 from yinxusen/SPARK-11846. 
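A rough usage sketch of the save/load support added below, shown for AFTSurvivalRegression (not part of the patch; the data and path are illustrative, and a Spark 1.6 `sqlContext` is assumed to be in scope, as in spark-shell):

    import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
    import org.apache.spark.mllib.linalg.Vectors

    // Tiny censored-survival dataset: (label, censor, features).
    val training = sqlContext.createDataFrame(Seq(
      (1.218, 1.0, Vectors.dense(1.560, -0.605)),
      (2.949, 0.0, Vectors.dense(0.346, 2.158)),
      (3.627, 0.0, Vectors.dense(1.380, 0.231)),
      (0.273, 1.0, Vectors.dense(0.520, 1.151)),
      (4.199, 0.0, Vectors.dense(0.795, -0.226))
    )).toDF("label", "censor", "features")

    val model = new AFTSurvivalRegression().fit(training)

    // Round-trip the fitted model; coefficients, intercept and scale are persisted.
    val path = java.nio.file.Files.createTempDirectory("aft-demo").toString + "/model"
    model.save(path)
    val restored = AFTSurvivalRegressionModel.load(path)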
--- .../ml/regression/AFTSurvivalRegression.scala | 78 +++++++++++++++-- .../ml/regression/IsotonicRegression.scala | 83 +++++++++++++++++-- .../AFTSurvivalRegressionSuite.scala | 37 ++++++++- .../regression/IsotonicRegressionSuite.scala | 34 +++++++- 4 files changed, 210 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index b7d095872ffa..aedfb48058dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -21,20 +21,20 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} +import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel +import org.apache.spark.{Logging, SparkException} /** * Params for accelerated failure time (AFT) regression. @@ -120,7 +120,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params @Experimental @Since("1.6.0") class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String) - extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging { + extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams + with DefaultParamsWritable with Logging { @Since("1.6.0") def this() = this(Identifiable.randomUID("aftSurvReg")) @@ -243,6 +244,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra) } +@Since("1.6.0") +object AFTSurvivalRegression extends DefaultParamsReadable[AFTSurvivalRegression] { + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegression = super.load(path) +} + /** * :: Experimental :: * Model produced by [[AFTSurvivalRegression]]. 
@@ -254,7 +262,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") val coefficients: Vector, @Since("1.6.0") val intercept: Double, @Since("1.6.0") val scale: Double) - extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { + extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable { /** @group setParam */ @Since("1.6.0") @@ -312,6 +320,58 @@ class AFTSurvivalRegressionModel private[ml] ( copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) .setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = + new AFTSurvivalRegressionModel.AFTSurvivalRegressionModelWriter(this) +} + +@Since("1.6.0") +object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[AFTSurvivalRegressionModel] = new AFTSurvivalRegressionModelReader + + @Since("1.6.0") + override def load(path: String): AFTSurvivalRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[AFTSurvivalRegressionModel]] */ + private[AFTSurvivalRegressionModel] class AFTSurvivalRegressionModelWriter ( + instance: AFTSurvivalRegressionModel + ) extends MLWriter with Logging { + + private case class Data(coefficients: Vector, intercept: Double, scale: Double) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: coefficients, intercept, scale + val data = Data(instance.coefficients, instance.intercept, instance.scale) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class AFTSurvivalRegressionModelReader extends MLReader[AFTSurvivalRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[AFTSurvivalRegressionModel].getName + + override def load(path: String): AFTSurvivalRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("coefficients", "intercept", "scale").head() + val coefficients = data.getAs[Vector](0) + val intercept = data.getDouble(1) + val scale = data.getDouble(2) + val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index a1fe01b04710..bbb1c7ac0a51 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -17,18 +17,22 @@ package org.apache.spark.ml.regression +import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{Estimator, 
Model} import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} +import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** @@ -127,7 +131,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures @Since("1.5.0") @Experimental class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Estimator[IsotonicRegressionModel] with IsotonicRegressionBase { + extends Estimator[IsotonicRegressionModel] + with IsotonicRegressionBase with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("isoReg")) @@ -179,6 +184,13 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri } } +@Since("1.6.0") +object IsotonicRegression extends DefaultParamsReadable[IsotonicRegression] { + + @Since("1.6.0") + override def load(path: String): IsotonicRegression = super.load(path) +} + /** * :: Experimental :: * Model fitted by IsotonicRegression. @@ -194,7 +206,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri class IsotonicRegressionModel private[ml] ( override val uid: String, private val oldModel: MLlibIsotonicRegressionModel) - extends Model[IsotonicRegressionModel] with IsotonicRegressionBase { + extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable { /** @group setParam */ @Since("1.5.0") @@ -240,4 +252,61 @@ class IsotonicRegressionModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false) } + + @Since("1.6.0") + override def write: MLWriter = + new IsotonicRegressionModelWriter(this) +} + +@Since("1.6.0") +object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { + + @Since("1.6.0") + override def read: MLReader[IsotonicRegressionModel] = new IsotonicRegressionModelReader + + @Since("1.6.0") + override def load(path: String): IsotonicRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[IsotonicRegressionModel]] */ + private[IsotonicRegressionModel] class IsotonicRegressionModelWriter ( + instance: IsotonicRegressionModel + ) extends MLWriter with Logging { + + private case class Data( + boundaries: Array[Double], + predictions: Array[Double], + isotonic: Boolean) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: boundaries, predictions, isotonic + val data = Data( + instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class IsotonicRegressionModelReader extends MLReader[IsotonicRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[IsotonicRegressionModel].getName + + override def load(path: String): 
IsotonicRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("boundaries", "predictions", "isotonic").head() + val boundaries = data.getAs[Seq[Double]](0).toArray + val predictions = data.getAs[Seq[Double]](1).toArray + val isotonic = data.getBoolean(2) + val model = new IsotonicRegressionModel( + metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 359f31027172..d718ef63b531 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -21,14 +21,15 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{DataFrame, Row} -class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class AFTSurvivalRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ @@ -332,4 +333,32 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex assert(prediction ~== model.predict(features) relTol 1E-5) } } + + test("read/write") { + def checkModelData( + model: AFTSurvivalRegressionModel, + model2: AFTSurvivalRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients === model2.coefficients) + assert(model.scale === model2.scale) + } + val aft = new AFTSurvivalRegression() + testEstimatorAndModelReadWrite(aft, datasetMultivariate, + AFTSurvivalRegressionSuite.allParamSettings, checkModelData) + } +} + +object AFTSurvivalRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
+ */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "fitIntercept" -> true, + "maxIter" -> 2, + "tol" -> 0.01 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 59f4193abc8f..f067c29d27a7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -19,12 +19,14 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class IsotonicRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { sqlContext.createDataFrame( labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } @@ -164,4 +166,32 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predictions === Array(3.5, 5.0, 5.0, 5.0)) } + + test("read/write") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + + def checkModelData(model: IsotonicRegressionModel, model2: IsotonicRegressionModel): Unit = { + assert(model.boundaries === model2.boundaries) + assert(model.predictions === model2.predictions) + assert(model.isotonic === model2.isotonic) + } + + val ir = new IsotonicRegression() + testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings, + checkModelData) + } +} + +object IsotonicRegressionSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "isotonic" -> true, + "featureIndex" -> 0 + ) } From 3b7f056da87a23f3a96f0311b3a947a9b698f38b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 19 Nov 2015 22:02:17 -0800 Subject: [PATCH 815/896] [SPARK-11829][ML] Add read/write to estimators under ml.feature (II) Add read/write support to the following estimators under spark.ml: * ChiSqSelector * PCA * VectorIndexer * Word2Vec Author: Yanbo Liang Closes #9838 from yanboliang/spark-11829. 
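A rough usage sketch of the read/write support added below for the ml.feature estimators, shown for PCA (not part of the patch; the data mirrors the new PCASuite test, the path is illustrative, and a Spark 1.6 `sqlContext` is assumed to be in scope):

    import org.apache.spark.ml.feature.{PCA, PCAModel}
    import org.apache.spark.mllib.linalg.Vectors

    val df = sqlContext.createDataFrame(Seq(
      Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
    ).map(Tuple1.apply)).toDF("features")

    val pcaModel = new PCA().setInputCol("features").setOutputCol("pcaFeatures").setK(3).fit(df)

    // Persist and reload; the principal components matrix survives the round trip.
    val path = java.nio.file.Files.createTempDirectory("pca-demo").toString + "/model"
    pcaModel.save(path)
    val restored = PCAModel.load(path)
    require(restored.pc.toArray.sameElements(pcaModel.pc.toArray))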
--- .../spark/ml/feature/ChiSqSelector.scala | 65 ++++++++++++++++-- .../org/apache/spark/ml/feature/PCA.scala | 67 +++++++++++++++++-- .../spark/ml/feature/VectorIndexer.scala | 66 ++++++++++++++++-- .../apache/spark/ml/feature/Word2Vec.scala | 67 +++++++++++++++++-- .../apache/spark/mllib/feature/Word2Vec.scala | 6 +- .../spark/ml/feature/ChiSqSelectorSuite.scala | 22 +++++- .../apache/spark/ml/feature/PCASuite.scala | 26 ++++++- .../spark/ml/feature/VectorIndexerSuite.scala | 22 +++++- .../spark/ml/feature/Word2VecSuite.scala | 30 ++++++++- 9 files changed, 338 insertions(+), 33 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 5e4061fba549..dfec03828f4b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -17,13 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.{AttributeGroup, _} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint @@ -60,7 +61,7 @@ private[feature] trait ChiSqSelectorParams extends Params */ @Experimental final class ChiSqSelector(override val uid: String) - extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams { + extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("chiSqSelector")) @@ -95,6 +96,13 @@ final class ChiSqSelector(override val uid: String) override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra) } +@Since("1.6.0") +object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] { + + @Since("1.6.0") + override def load(path: String): ChiSqSelector = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[ChiSqSelector]]. @@ -103,7 +111,12 @@ final class ChiSqSelector(override val uid: String) final class ChiSqSelectorModel private[ml] ( override val uid: String, private val chiSqSelector: feature.ChiSqSelectorModel) - extends Model[ChiSqSelectorModel] with ChiSqSelectorParams { + extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable { + + import ChiSqSelectorModel._ + + /** list of indices to select (filter). 
Must be ordered asc */ + val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures /** @group setParam */ def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -147,4 +160,46 @@ final class ChiSqSelectorModel private[ml] ( val copied = new ChiSqSelectorModel(uid, chiSqSelector) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new ChiSqSelectorModelWriter(this) +} + +@Since("1.6.0") +object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { + + private[ChiSqSelectorModel] + class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter { + + private case class Data(selectedFeatures: Seq[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.selectedFeatures.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] { + + private val className = classOf[ChiSqSelectorModel].getName + + override def load(path: String): ChiSqSelectorModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head() + val selectedFeatures = data.getAs[Seq[Int]](0).toArray + val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) + val model = new ChiSqSelectorModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader + + @Since("1.6.0") + override def load(path: String): ChiSqSelectorModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 539084704b65..32d7afee6e73 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -17,13 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -49,7 +51,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC * PCA trains a model to project vectors to a low-dimensional space using PCA. 
*/ @Experimental -class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams { +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("pca")) @@ -86,6 +89,13 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams override def copy(extra: ParamMap): PCA = defaultCopy(extra) } +@Since("1.6.0") +object PCA extends DefaultParamsReadable[PCA] { + + @Since("1.6.0") + override def load(path: String): PCA = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[PCA]]. @@ -94,7 +104,12 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams class PCAModel private[ml] ( override val uid: String, pcaModel: feature.PCAModel) - extends Model[PCAModel] with PCAParams { + extends Model[PCAModel] with PCAParams with MLWritable { + + import PCAModel._ + + /** a principal components Matrix. Each column is one principal component. */ + val pc: DenseMatrix = pcaModel.pc /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -127,4 +142,46 @@ class PCAModel private[ml] ( val copied = new PCAModel(uid, pcaModel) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new PCAModelWriter(this) +} + +@Since("1.6.0") +object PCAModel extends MLReadable[PCAModel] { + + private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { + + private case class Data(k: Int, pc: DenseMatrix) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.getK, instance.pc) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class PCAModelReader extends MLReader[PCAModel] { + + private val className = classOf[PCAModel].getName + + override def load(path: String): PCAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath) + .select("k", "pc") + .head() + val oldModel = new feature.PCAModel(k, pc) + val model = new PCAModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[PCAModel] = new PCAModelReader + + @Since("1.6.0") + override def load(path: String): PCAModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 52e0599e38d8..a637a6f2881d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -22,12 +22,14 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, 
SparseVector, Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.udf @@ -93,7 +95,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu */ @Experimental class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] - with VectorIndexerParams { + with VectorIndexerParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("vecIdx")) @@ -136,7 +138,11 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) } -private object VectorIndexer { +@Since("1.6.0") +object VectorIndexer extends DefaultParamsReadable[VectorIndexer] { + + @Since("1.6.0") + override def load(path: String): VectorIndexer = super.load(path) /** * Helper class for tracking unique values for each feature. @@ -146,7 +152,7 @@ private object VectorIndexer { * @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures. * @param maxCategories This class caps the number of unique values collected at maxCategories. */ - class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) + private class CategoryStats(private val numFeatures: Int, private val maxCategories: Int) extends Serializable { /** featureValueSets[feature index] = set of unique values */ @@ -252,7 +258,9 @@ class VectorIndexerModel private[ml] ( override val uid: String, val numFeatures: Int, val categoryMaps: Map[Int, Map[Double, Int]]) - extends Model[VectorIndexerModel] with VectorIndexerParams { + extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable { + + import VectorIndexerModel._ /** Java-friendly version of [[categoryMaps]] */ def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = { @@ -408,4 +416,48 @@ class VectorIndexerModel private[ml] ( val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new VectorIndexerModelWriter(this) +} + +@Since("1.6.0") +object VectorIndexerModel extends MLReadable[VectorIndexerModel] { + + private[VectorIndexerModel] + class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter { + + private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.numFeatures, instance.categoryMaps) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class VectorIndexerModelReader extends MLReader[VectorIndexerModel] { + + private val className = classOf[VectorIndexerModel].getName + + override def load(path: String): VectorIndexerModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("numFeatures", "categoryMaps") + .head() + val numFeatures = data.getAs[Int](0) + val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1) + val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[VectorIndexerModel] = new VectorIndexerModelReader + + @Since("1.6.0") + override def load(path: String): 
VectorIndexerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 708dbeef84db..a8d61b6dea00 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,15 +17,17 @@ package org.apache.spark.ml.feature +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -90,7 +92,8 @@ private[feature] trait Word2VecBase extends Params * natural language processing or machine learning process. */ @Experimental -final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { +final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase + with DefaultParamsWritable { def this() = this(Identifiable.randomUID("w2v")) @@ -139,6 +142,13 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) } +@Since("1.6.0") +object Word2Vec extends DefaultParamsReadable[Word2Vec] { + + @Since("1.6.0") + override def load(path: String): Word2Vec = super.load(path) +} + /** * :: Experimental :: * Model fitted by [[Word2Vec]]. 
@@ -147,7 +157,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] class Word2VecModel private[ml] ( override val uid: String, @transient private val wordVectors: feature.Word2VecModel) - extends Model[Word2VecModel] with Word2VecBase { + extends Model[Word2VecModel] with Word2VecBase with MLWritable { + + import Word2VecModel._ /** * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and @@ -224,4 +236,49 @@ class Word2VecModel private[ml] ( val copied = new Word2VecModel(uid, wordVectors) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new Word2VecModelWriter(this) +} + +@Since("1.6.0") +object Word2VecModel extends MLReadable[Word2VecModel] { + + private[Word2VecModel] + class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter { + + private case class Data(wordIndex: Map[String, Int], wordVectors: Seq[Float]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class Word2VecModelReader extends MLReader[Word2VecModel] { + + private val className = classOf[Word2VecModel].getName + + override def load(path: String): Word2VecModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("wordIndex", "wordVectors") + .head() + val wordIndex = data.getAs[Map[String, Int]](0) + val wordVectors = data.getAs[Seq[Float]](1).toArray + val oldModel = new feature.Word2VecModel(wordIndex, wordVectors) + val model = new Word2VecModel(metadata.uid, oldModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[Word2VecModel] = new Word2VecModelReader + + @Since("1.6.0") + override def load(path: String): Word2VecModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 7ab0d89d23a3..a47f27b0afb1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -432,9 +432,9 @@ class Word2Vec extends Serializable with Logging { * (i * vectorSize, i * vectorSize + vectorSize) */ @Since("1.1.0") -class Word2VecModel private[mllib] ( - private val wordIndex: Map[String, Int], - private val wordVectors: Array[Float]) extends Serializable with Saveable { +class Word2VecModel private[spark] ( + private[spark] val wordIndex: Map[String, Int], + private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable { private val numWords = wordIndex.size // vectorSize: Dimension of each word's vector. 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index e5a42967bd2c..7827db2794cf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -18,13 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} -class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + test("Test Chi-Square selector") { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ @@ -58,4 +62,20 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(vec1 ~== vec2 absTol 1e-1) } } + + test("ChiSqSelector read/write") { + val t = new ChiSqSelector() + .setFeaturesCol("myFeaturesCol") + .setLabelCol("myLabelCol") + .setOutputCol("myOutputCol") + .setNumTopFeatures(2) + testDefaultReadWrite(t) + } + + test("ChiSqSelectorModel read/write") { + val oldModel = new feature.ChiSqSelectorModel(Array(1, 3)) + val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.selectedFeatures === instance.selectedFeatures) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 30c500f87a76..5a21cd20ceed 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel} import org.apache.spark.sql.Row -class PCASuite extends SparkFunSuite with MLlibTestSparkContext { +class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new PCA) @@ -65,4 +65,24 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext { assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") } } + + test("read/write") { + + def checkModelData(model1: PCAModel, model2: PCAModel): Unit = { + assert(model1.pc === model2.pc) + } + val allParams: Map[String, Any] = Map( + "k" -> 3, + "inputCol" -> "features", + "outputCol" -> "pca_features" + ) + val data = Seq( + (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))), + (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), + (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) + ) + val df = 
sqlContext.createDataFrame(data).toDF("id", "features") + val pca = new PCA().setK(3) + testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 8cb0a2cf14d3..67817fa4baf5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -22,13 +22,14 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest with Logging { import VectorIndexerSuite.FeatureData @@ -251,6 +252,23 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L } } } + + test("VectorIndexer read/write") { + val t = new VectorIndexer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxCategories(30) + testDefaultReadWrite(t) + } + + test("VectorIndexerModel read/write") { + val categoryMaps = Map(0 -> Map(0.0 -> 0, 1.0 -> 1), 1 -> Map(0.0 -> 0, 1.0 -> 1, + 2.0 -> 2, 3.0 -> 3), 2 -> Map(0.0 -> 0, -1.0 -> 1, 2.0 -> 2)) + val instance = new VectorIndexerModel("myVectorIndexerModel", 3, categoryMaps) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.numFeatures === instance.numFeatures) + assert(newInstance.categoryMaps === instance.categoryMaps) + } } private[feature] object VectorIndexerSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 23dfdaa9f8fc..a773244cd735 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} -class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Word2Vec) @@ -143,5 +143,31 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } } + + test("Word2Vec read/write") { + val t = new Word2Vec() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMaxIter(2) + .setMinCount(8) + .setNumPartitions(1) + .setSeed(42L) + .setStepSize(0.01) + .setVectorSize(100) + testDefaultReadWrite(t) + 
} + + test("Word2VecModel read/write") { + val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val oldModel = new OldWord2VecModel(word2VecMap) + val instance = new Word2VecModel("myWord2VecModel", oldModel) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.getVectors.collect() === instance.getVectors.collect()) + } } From 7216f405454f6f3557b5b1f72df8f393605faf60 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 19 Nov 2015 22:14:01 -0800 Subject: [PATCH 816/896] [SPARK-11875][ML][PYSPARK] Update doc for PySpark HasCheckpointInterval * Update doc for PySpark ```HasCheckpointInterval``` that users can understand how to disable checkpoint. * Update doc for PySpark ```cacheNodeIds``` of ```DecisionTreeParams``` to notify the relationship between ```cacheNodeIds``` and ```checkpointInterval```. Author: Yanbo Liang Closes #9856 from yanboliang/spark-11875. --- python/pyspark/ml/param/_shared_params_code_gen.py | 6 ++++-- python/pyspark/ml/param/shared.py | 14 +++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 070c5db01ae7..0528dc1e3a6b 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -118,7 +118,8 @@ def get$Name(self): ("inputCols", "input column names.", None), ("outputCol", "output column name.", "self.uid + '__output'"), ("numFeatures", "number of features.", None), - ("checkpointInterval", "checkpoint interval (>= 1).", None), + ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + + "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None), ("seed", "random seed.", "hash(type(self).__name__)"), ("tol", "the convergence tolerance for iterative algorithms.", None), ("stepSize", "Step size to be used for each iteration of optimization.", None), @@ -157,7 +158,8 @@ def get$Name(self): ("maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation."), ("cacheNodeIds", "If false, the algorithm will pass trees to executors to match " + "instances with nodes. If true, the algorithm will cache node IDs for each instance. " + - "Caching can speed up training of deeper trees.")] + "Caching can speed up training of deeper trees. Users can set how often should the " + + "cache be checkpointed or disable it by setting checkpointInterval.")] decisionTreeCode = '''class DecisionTreeParams(Params): """ diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 4bdf2a8cc563..4d960801502c 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -325,16 +325,16 @@ def getNumFeatures(self): class HasCheckpointInterval(Params): """ - Mixin for param checkpointInterval: checkpoint interval (>= 1). + Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. """ # a placeholder to make it appear in the generated doc - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "checkpoint interval (>= 1).") + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 
10 means that the cache will get checkpointed every 10 iterations.") def __init__(self): super(HasCheckpointInterval, self).__init__() - #: param for checkpoint interval (>= 1). - self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1).") + #: param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. + self.checkpointInterval = Param(self, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") def setCheckpointInterval(self, value): """ @@ -636,7 +636,7 @@ class DecisionTreeParams(Params): minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.") minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") def __init__(self): @@ -651,8 +651,8 @@ def __init__(self): self.minInfoGain = Param(self, "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") #: param for Maximum memory in MB allocated to histogram aggregation. self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") - #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. - self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval. + self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.") def setMaxDepth(self, value): """ From 0fff8eb3e476165461658d4e16682ec64269fdfe Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Thu, 19 Nov 2015 23:42:24 -0800 Subject: [PATCH 817/896] [SPARK-11869][ML] Clean up TempDirectory properly in ML tests Need to remove parent directory (```className```) rather than just tempDir (```className/random_name```) I tested this with IDFSuite, which has 2 read/write tests, and it fixes the problem. CC: mengxr Can you confirm this is fine? I believe it is since the same ```random_name``` is used for all tests in a suite; we basically have an extra unneeded level of nesting. Author: Joseph K. Bradley Closes #9851 from jkbradley/tempdir-cleanup. --- .../src/test/scala/org/apache/spark/ml/util/TempDirectory.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index 2742026a69c2..c8a0bb16247b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -35,7 +35,7 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => override def beforeAll(): Unit = { super.beforeAll() - _tempDir = Utils.createTempDir(this.getClass.getName) + _tempDir = Utils.createTempDir(namePrefix = this.getClass.getName) } override def afterAll(): Unit = { From 3e1d120cedb4bd9e1595e95d4d531cf61da6684d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 19 Nov 2015 23:43:18 -0800 Subject: [PATCH 818/896] [SPARK-11867] Add save/load for kmeans and naive bayes https://issues.apache.org/jira/browse/SPARK-11867 Author: Xusen Yin Closes #9849 from yinxusen/SPARK-11867. --- .../spark/ml/classification/NaiveBayes.scala | 68 +++++++++++++++++-- .../apache/spark/ml/clustering/KMeans.scala | 67 ++++++++++++++++-- .../ml/classification/NaiveBayesSuite.scala | 47 +++++++++++-- .../spark/ml/clustering/KMeansSuite.scala | 41 ++++++++--- 4 files changed, 195 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index a14dcecbaf5b..c512a2cb8bf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -17,12 +17,15 @@ package org.apache.spark.ml.classification +import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel} +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} +import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { @Experimental class NaiveBayes(override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] - with NaiveBayesParams { + with NaiveBayesParams with DefaultParamsWritable { def this() = this(Identifiable.randomUID("nb")) @@ -102,6 +105,13 @@ class NaiveBayes(override val uid: 
String) override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra) } +@Since("1.6.0") +object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { + + @Since("1.6.0") + override def load(path: String): NaiveBayes = super.load(path) +} + /** * :: Experimental :: * Model produced by [[NaiveBayes]] @@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, val theta: Matrix) - extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] + with NaiveBayesParams with MLWritable { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] ( s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } + @Since("1.6.0") + override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this) } -private[ml] object NaiveBayesModel { +@Since("1.6.0") +object NaiveBayesModel extends MLReadable[NaiveBayesModel] { /** Convert a model from the old API */ - def fromOld( + private[ml] def fromOld( oldModel: OldNaiveBayesModel, parent: NaiveBayes): NaiveBayesModel = { val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") @@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel { oldModel.theta.flatten, true) new NaiveBayesModel(uid, pi, theta) } + + @Since("1.6.0") + override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader + + @Since("1.6.0") + override def load(path: String): NaiveBayesModel = super.load(path) + + /** [[MLWriter]] instance for [[NaiveBayesModel]] */ + private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter { + + private case class Data(pi: Vector, theta: Matrix) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: pi, theta + val data = Data(instance.pi, instance.theta) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[NaiveBayesModel].getName + + override def load(path: String): NaiveBayesModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head() + val pi = data.getAs[Vector](0) + val theta = data.getAs[Matrix](1) + val model = new NaiveBayesModel(metadata.uid, pi, theta) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 509be6300239..71e968497500 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,10 +17,12 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import 
org.apache.spark.ml.util._ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -28,7 +30,6 @@ import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.{DataFrame, Row} - /** * Common params for KMeans and KMeansModel */ @@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe @Experimental class KMeansModel private[ml] ( @Since("1.5.0") override val uid: String, - private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + private val parentModel: MLlibKMeansModel) + extends Model[KMeansModel] with KMeansParams with MLWritable { @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { @@ -129,6 +131,52 @@ class KMeansModel private[ml] ( val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } parentModel.computeCost(data) } + + @Since("1.6.0") + override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) +} + +@Since("1.6.0") +object KMeansModel extends MLReadable[KMeansModel] { + + @Since("1.6.0") + override def read: MLReader[KMeansModel] = new KMeansModelReader + + @Since("1.6.0") + override def load(path: String): KMeansModel = super.load(path) + + /** [[MLWriter]] instance for [[KMeansModel]] */ + private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { + + private case class Data(clusterCenters: Array[Vector]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: cluster centers + val data = Data(instance.clusterCenters) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class KMeansModelReader extends MLReader[KMeansModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[KMeansModel].getName + + override def load(path: String): KMeansModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head() + val clusterCenters = data.getAs[Seq[Vector]](0).toArray + val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** @@ -141,7 +189,7 @@ class KMeansModel private[ml] ( @Experimental class KMeans @Since("1.5.0") ( @Since("1.5.0") override val uid: String) - extends Estimator[KMeansModel] with KMeansParams { + extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable { setDefault( k -> 2, @@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") ( } } +@Since("1.6.0") +object KMeans extends DefaultParamsReadable[KMeans] { + + @Since("1.6.0") + override def load(path: String): KMeans = super.load(path) +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 98bc9511163e..082a6bcd211a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -21,15 +21,30 @@ import 
breeze.linalg.{Vector => BV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} +import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.classification.NaiveBayesSuite._ -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} + +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) -class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) + } def validatePrediction(predictionAndLabels: DataFrame): Unit = { val numOfErrorPredictions = predictionAndLabels.collect().count { @@ -161,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { .select("features", "probability") validateProbabilities(featureAndProbabilities, model, "bernoulli") } + + test("read/write") { + def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = { + assert(model.pi === model2.pi) + assert(model.theta === model2.theta) + } + val nb = new NaiveBayes() + testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) + } +} + +object NaiveBayesSuite { + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. 
+ */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "smoothing" -> 0.1 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c05f90550d16..2724e51f31aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -25,16 +26,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} private[clustering] case class TestRow(features: Vector) -object KMeansSuite { - def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { - val sc = sql.sparkContext - val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) - .map(v => new TestRow(v)) - sql.createDataFrame(rdd) - } -} - -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 @transient var dataset: DataFrame = _ @@ -106,4 +98,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) } + + test("read/write") { + def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { + assert(model.clusterCenters === model2.clusterCenters) + } + val kmeans = new KMeans() + testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData) + } +} + +object KMeansSuite { + def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { + val sc = sql.sparkContext + val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) + .map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "k" -> 3, + "maxIter" -> 2, + "tol" -> 0.01 + ) } From a66142decee48bf5689fb7f4f33646d7bb1ac08d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 00:46:29 -0800 Subject: [PATCH 819/896] [SPARK-11877] Prevent agg. fallback conf. from leaking across test suites This patch fixes an issue where the `spark.sql.TungstenAggregate.testFallbackStartsAt` SQLConf setting was not properly reset / cleared at the end of `TungstenAggregationQueryWithControlledFallbackSuite`. This ended up causing test failures in HiveCompatibilitySuite in Maven builds by causing spilling to occur way too frequently. This configuration leak was inadvertently introduced during test cleanup in #9618. Author: Josh Rosen Closes #9857 from JoshRosen/clear-fallback-prop-in-test-teardown. 
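
A minimal sketch of the scoping pattern this fix switches to, assuming a suite that mixes in Spark SQL's test utilities (which supply `withSQLConf`); the test name and the value "2" are illustrative. The override is confined to the block and restored when it exits, so it cannot leak into suites that run later in the same JVM:

```scala
// Sketch only: prefer a scoped override to calling sqlContext.setConf in tests.
test("aggregation falls back to the sort-based path") {
  withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> "2") {
    // build the DataFrame and run the assertions that exercise the fallback here;
    // the conf value "2" applies only inside this block
  }
  // outside the block the previous value (or the default) is in effect again
}
```

The original `setConf` call had no matching cleanup, which is how the setting escaped the suite.
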
--- .../execution/AggregationQuerySuite.scala | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 6dde79f74d3d..39c0a2a0de04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -868,29 +868,27 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { (0 to 2).foreach { fallbackStartsAt => - sqlContext.setConf( - "spark.sql.TungstenAggregate.testFallbackStartsAt", - fallbackStartsAt.toString) - - // Create a new df to make sure its physical operator picks up - // spark.sql.TungstenAggregate.testFallbackStartsAt. - // todo: remove it? - val newActual = DataFrame(sqlContext, actual.logicalPlan) - - QueryTest.checkAnswer(newActual, expectedAnswer) match { - case Some(errorMessage) => - val newErrorMessage = - s""" - |The following aggregation query failed when using TungstenAggregate with - |controlled fallback (it falls back to sort-based aggregation once it has processed - |$fallbackStartsAt input rows). The query is - |${actual.queryExecution} - | - |$errorMessage - """.stripMargin - - fail(newErrorMessage) - case None => + withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> fallbackStartsAt.toString) { + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? + val newActual = DataFrame(sqlContext, actual.logicalPlan) + + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to sort-based aggregation once it has processed + |$fallbackStartsAt input rows). The query is + |${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => + } } } } From 9ace2e5c8d7fbd360a93bc5fc4eace64a697b44f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 20 Nov 2015 09:55:53 -0800 Subject: [PATCH 820/896] [SPARK-11852][ML] StandardScaler minor refactor ```withStd``` and ```withMean``` should be params of ```StandardScaler``` and ```StandardScalerModel```. Author: Yanbo Liang Closes #9839 from yanboliang/standardScaler-refactor. --- .../spark/ml/feature/StandardScaler.scala | 60 +++++++++---------- .../ml/feature/StandardScalerSuite.scala | 11 ++-- 2 files changed, 32 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 6d545219ebf4..d76a9c6275e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -36,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType} private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { /** - * Centers the data with mean before scaling. + * Whether to center the data with mean before scaling. * It will build a dense output, so this does not work on sparse input * and will raise an exception. 
* Default: false * @group param */ - val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") + val withMean: BooleanParam = new BooleanParam(this, "withMean", + "Whether to center data with mean") + + /** @group getParam */ + def getWithMean: Boolean = $(withMean) /** - * Scales the data to unit standard deviation. + * Whether to scale the data to unit standard deviation. * Default: true * @group param */ - val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") + val withStd: BooleanParam = new BooleanParam(this, "withStd", + "Whether to scale the data to unit standard deviation") + + /** @group getParam */ + def getWithStd: Boolean = $(withStd) + + setDefault(withMean -> false, withStd -> true) } /** @@ -63,8 +73,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM def this() = this(Identifiable.randomUID("stdScal")) - setDefault(withMean -> false, withStd -> true) - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -82,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) - copyValues(new StandardScalerModel(uid, scalerModel).setParent(this)) + copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -108,29 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] { /** * :: Experimental :: * Model fitted by [[StandardScaler]]. + * + * @param std Standard deviation of the StandardScalerModel + * @param mean Mean of the StandardScalerModel */ @Experimental class StandardScalerModel private[ml] ( override val uid: String, - scaler: feature.StandardScalerModel) + val std: Vector, + val mean: Vector) extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { import StandardScalerModel._ - /** Standard deviation of the StandardScalerModel */ - val std: Vector = scaler.std - - /** Mean of the StandardScalerModel */ - val mean: Vector = scaler.mean - - /** Whether to scale to unit standard deviation. */ - @Since("1.6.0") - def getWithStd: Boolean = scaler.withStd - - /** Whether to center data with mean. 
*/ - @Since("1.6.0") - def getWithMean: Boolean = scaler.withMean - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -139,6 +137,7 @@ class StandardScalerModel private[ml] ( override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) val scale = udf { scaler.transform _ } dataset.withColumn($(outputCol), scale(col($(inputCol)))) } @@ -154,7 +153,7 @@ class StandardScalerModel private[ml] ( } override def copy(extra: ParamMap): StandardScalerModel = { - val copied = new StandardScalerModel(uid, scaler) + val copied = new StandardScalerModel(uid, std, mean) copyValues(copied, extra).setParent(parent) } @@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private[StandardScalerModel] class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter { - private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) + private case class Data(std: Vector, mean: Vector) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean) + val data = Data(instance.std, instance.mean) val dataPath = new Path(path, "data").toString sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } @@ -185,13 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) = - sqlContext.read.parquet(dataPath) - .select("std", "mean", "withStd", "withMean") - .head() - // This is very likely to change in the future because withStd and withMean should be params. 
- val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean) - val model = new StandardScalerModel(metadata.uid, oldModel) + val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath) + .select("std", "mean") + .head() + val model = new StandardScalerModel(metadata.uid, std, mean) DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 49a4b2efe0c2..1eae125a524e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -70,8 +70,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext test("params") { ParamsSuite.checkParams(new StandardScaler) - val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0)) - ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel)) + ParamsSuite.checkParams(new StandardScalerModel("empty", + Vectors.dense(1.0), Vectors.dense(2.0))) } test("Standardization with default parameter") { @@ -126,13 +126,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("StandardScalerModel read/write") { - val oldModel = new feature.StandardScalerModel( - Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true) - val instance = new StandardScalerModel("myStandardScalerModel", oldModel) + val instance = new StandardScalerModel("myStandardScalerModel", + Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0)) val newInstance = testDefaultReadWrite(instance) assert(newInstance.std === instance.std) assert(newInstance.mean === instance.mean) - assert(newInstance.getWithStd === instance.getWithStd) - assert(newInstance.getWithMean === instance.getWithMean) } } From e359d5dcf5bd300213054ebeae9fe75c4f7eb9e7 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 20 Nov 2015 09:57:09 -0800 Subject: [PATCH 821/896] [SPARK-11689][ML] Add user guide and example code for LDA under spark.ml jira: https://issues.apache.org/jira/browse/SPARK-11689 Add simple user guide for LDA under spark.ml and example code under examples/. Use include_example to include example code in the user guide markdown. Check SPARK-11606 for instructions. Author: Yuhao Yang Closes #9722 from hhbyyh/ldaMLExample. --- docs/ml-clustering.md | 30 ++++++ docs/ml-guide.md | 3 +- docs/mllib-guide.md | 1 + .../spark/examples/ml/JavaLDAExample.java | 94 +++++++++++++++++++ .../apache/spark/examples/ml/LDAExample.scala | 77 +++++++++++++++ 5 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 docs/ml-clustering.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md new file mode 100644 index 000000000000..1743ef43a6dd --- /dev/null +++ b/docs/ml-clustering.md @@ -0,0 +1,30 @@ +--- +layout: global +title: Clustering - ML +displayTitle: ML - Clustering +--- + +In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). + +## Latent Dirichlet allocation (LDA) + +`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, +and generates a `LDAModel` as the base models. 
Expert users may cast a `LDAModel` generated by +`EMLDAOptimizer` to a `DistributedLDAModel` if needed. + +
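
As an illustrative aside (not part of the guide file being added in this patch), a minimal Scala sketch of the cast mentioned above; it assumes a `DataFrame` named `dataset` with a vector column under the default name `"features"`, and the cast is only valid when the optimizer is `"em"`:

```scala
import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

// Fit with the EM optimizer so the fitted model keeps its distributed representation.
val lda = new LDA()
  .setK(10)
  .setMaxIter(10)
  .setOptimizer("em")
val model = lda.fit(dataset)  // `dataset` is an assumed DataFrame with a "features" column

// Expert-level access described above; only safe because the optimizer is "em".
val distributedModel = model.asInstanceOf[DistributedLDAModel]
```

A pattern match on `DistributedLDAModel` is a safer alternative to the blind cast, since the `"online"` optimizer yields a `LocalLDAModel` instead.
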
    + +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details. + +
    +{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %} +
    + +
    + +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %} +
    + +
    \ No newline at end of file diff --git a/docs/ml-guide.md b/docs/ml-guide.md index be18a05361a1..6f35b30c3d4d 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -40,6 +40,7 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., provide class probabilities, and linear models provide model summaries. * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) @@ -950,4 +951,4 @@ model.transform(test) {% endhighlight %} - + \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 91e50ccfecec..54e35fcbb15a 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -69,6 +69,7 @@ We list major functionality from both below, with links to detailed guides. concepts. It also contains sections on using algorithms within the Pipelines API, for example: * [Feature extraction, transformation, and selection](ml-features.html) +* [Clustering](ml-clustering.html) * [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java new file mode 100644 index 000000000000..b3a7d2eb2978 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.regex.Pattern; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.clustering.LDA; +import org.apache.spark.ml.clustering.LDAModel; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * An example demonstrating LDA + * Run with + *
    + * bin/run-example ml.JavaLDAExample
    + * 
    + */ +public class JavaLDAExample { + + private static class ParseVector implements Function { + private static final Pattern separator = Pattern.compile(" "); + + @Override + public Row call(String line) { + String[] tok = separator.split(line); + double[] point = new double[tok.length]; + for (int i = 0; i < tok.length; ++i) { + point[i] = Double.parseDouble(tok[i]); + } + Vector[] points = {Vectors.dense(point)}; + return new GenericRow(points); + } + } + + public static void main(String[] args) { + + String inputFile = "data/mllib/sample_lda_data.txt"; + + // Parses the arguments + SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // Loads data + JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); + StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; + StructType schema = new StructType(fields); + DataFrame dataset = sqlContext.createDataFrame(points, schema); + + // Trains a LDA model + LDA lda = new LDA() + .setK(10) + .setMaxIter(10); + LDAModel model = lda.fit(dataset); + + System.out.println(model.logLikelihood(dataset)); + System.out.println(model.logPerplexity(dataset)); + + // Shows the result + DataFrame topics = model.describeTopics(3); + topics.show(false); + model.transform(dataset).show(false); + + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala new file mode 100644 index 000000000000..419ce3d87a6a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +// $example on$ +import org.apache.spark.ml.clustering.LDA +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.types.{StructField, StructType} +// $example off$ + +/** + * An example demonstrating a LDA of ML pipeline. 
+ * Run with + * {{{ + * bin/run-example ml.LDAExample + * }}} + */ +object LDAExample { + + final val FEATURES_COL = "features" + + def main(args: Array[String]): Unit = { + + val input = "data/mllib/sample_lda_data.txt" + // Creates a Spark context and a SQL context + val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Loads data + val rowRDD = sc.textFile(input).filter(_.nonEmpty) + .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) + val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) + val dataset = sqlContext.createDataFrame(rowRDD, schema) + + // Trains a LDA model + val lda = new LDA() + .setK(10) + .setMaxIter(10) + .setFeaturesCol(FEATURES_COL) + val model = lda.fit(dataset) + val transformed = model.transform(dataset) + + val ll = model.logLikelihood(dataset) + val lp = model.logPerplexity(dataset) + + // describeTopics + val topics = model.describeTopics(3) + + // Shows the result + topics.show(false) + transformed.show(false) + + // $example off$ + sc.stop() + } +} +// scalastyle:on println From bef361c589c0a38740232fd8d0a45841e4fc969a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 20 Nov 2015 11:20:47 -0800 Subject: [PATCH 822/896] [SPARK-11876][SQL] Support printSchema in DataSet API DataSet APIs look great! However, I am lost when doing multiple level joins. For example, ``` val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a") val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b") val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c") ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2").printSchema() ``` The printed schema is like ``` root |-- _1: struct (nullable = true) | |-- _1: struct (nullable = true) | | |-- _1: string (nullable = true) | | |-- _2: integer (nullable = true) | |-- _2: struct (nullable = true) | | |-- _1: string (nullable = true) | | |-- _2: integer (nullable = true) |-- _2: struct (nullable = true) | |-- _1: string (nullable = true) | |-- _2: integer (nullable = true) ``` Personally, I think we need the printSchema function. Sometimes, I do not know how to specify the column, especially when their data types are mixed. For example, if I want to write the following select for the above multi-level join, I have to know the schema: ``` newDS.select(expr("_1._2._2 + 1").as[Int]).collect() ``` marmbrus rxin cloud-fan Do you have the same feeling? Author: gatorsmile Closes #9855 from gatorsmile/printSchemaDataSet. --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 9 --------- .../scala/org/apache/spark/sql/execution/Queryable.scala | 9 +++++++++ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 98358127e270..7abcecaa2880 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -299,15 +299,6 @@ class DataFrame private[sql]( */ def columns: Array[String] = schema.fields.map(_.name) - /** - * Prints the schema to the console in a nice tree format. - * @group basic - * @since 1.3.0 - */ - // scalastyle:off println - def printSchema(): Unit = println(schema.treeString) - // scalastyle:on println - /** * Returns true if the `collect` and `take` methods can be run locally * (without any Spark executors). 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index e86a52c149a2..321e2c783537 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -37,6 +37,15 @@ private[sql] trait Queryable { } } + /** + * Prints the schema to the console in a nice tree format. + * @group basic + * @since 1.3.0 + */ + // scalastyle:off println + def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println + /** * Prints the plans (logical and physical) to the console for debugging purposes. * @since 1.3.0 From 60bfb113325c71491f8dcf98b6036b0caa2144fe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 20 Nov 2015 11:43:45 -0800 Subject: [PATCH 823/896] [SPARK-11817][SQL] Truncating the fractional seconds to prevent inserting a NULL JIRA: https://issues.apache.org/jira/browse/SPARK-11817 Instead of return None, we should truncate the fractional seconds to prevent inserting NULL. Author: Liang-Chi Hsieh Closes #9834 from viirya/truncate-fractional-sec. --- .../apache/spark/sql/catalyst/util/DateTimeUtils.scala | 5 +++++ .../spark/sql/catalyst/util/DateTimeUtilsSuite.scala | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 17a5527f3fb2..2b9388291948 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -327,6 +327,11 @@ object DateTimeUtils { return None } + // Instead of return None, we truncate the fractional seconds to prevent inserting NULL + if (segments(6) > 999999) { + segments(6) = segments(6).toString.take(6).toInt + } + if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || segments(5) < 0 || segments(5) > 59 || segments(6) < 0 || segments(6) > 999999 || segments(7) < 0 || segments(7) > 23 || segments(8) < 0 || segments(8) > 59) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index faca128badfd..0ce5a2fb6950 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -343,6 +343,14 @@ class DateTimeUtilsSuite extends SparkFunSuite { UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) + + // Truncating the fractional seconds + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+00:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + assert(stringToTimestamp( + UTF8String.fromString("2015-03-18T12:03:17.123456789+0:00")).get === + c.getTimeInMillis * 1000 + 123456) } test("hours") { From 3b9d2a347f9c796b90852173d84189834e499e25 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Nov 2015 12:04:42 -0800 Subject: [PATCH 824/896] [SPARK-11819][SQL] nice error message for missing encoder before this PR, when users try to get an encoder for an un-supported class, they will only get a very simple error message like `Encoder for type xxx is not 
supported`. After this PR, the error message become more friendly, for example: ``` No Encoder found for abc.xyz.NonEncodable - array element class: "abc.xyz.NonEncodable" - field (class: "scala.Array", name: "arrayField") - root class: "abc.xyz.AnotherClass" ``` Author: Wenchen Fan Closes #9810 from cloud-fan/error-message. --- .../spark/sql/catalyst/ScalaReflection.scala | 90 ++++++++++++++----- .../encoders/EncoderErrorMessageSuite.scala | 62 +++++++++++++ 2 files changed, 129 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 33ae700706da..918050b531c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -63,7 +63,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.BooleanTpe => BooleanType case t if t <:< localTypeOf[Array[Byte]] => BinaryType case _ => - val className: String = tpe.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(tpe) className match { case "scala.Array" => val TypeRef(_, _, Seq(elementType)) = tpe @@ -320,9 +320,23 @@ object ScalaReflection extends ScalaReflection { } } - /** Returns expressions for extracting all the fields from the given type. */ + /** + * Returns expressions for extracting all the fields from the given type. + * + * If the given type is not supported, i.e. there is no encoder can be built for this type, + * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain + * the type path walked so far and which class we are not supporting. + * There are 4 kinds of type path: + * * the root type: `root class: "abc.xyz.MyClass"` + * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"` + * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` + * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` + */ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { - extractorFor(inputObject, localTypeOf[T]) match { + val tpe = localTypeOf[T] + val clsName = getClassNameFromType(tpe) + val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + extractorFor(inputObject, tpe, walkedTypePath) match { case s: CreateNamedStruct => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } @@ -331,7 +345,28 @@ object ScalaReflection extends ScalaReflection { /** Helper for extracting internal fields from a case class. */ private def extractorFor( inputObject: Expression, - tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + tpe: `Type`, + walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { + + def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = silentSchemaFor(elementType) + if (isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + val clsName = getClassNameFromType(elementType) + val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath + // `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here + // to trigger the type check. 
+ extractorFor(inputObject, elementType, newPath) + + MapObjects(extractorFor(_, elementType, newPath), input, externalDataType) + } + } + if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { @@ -378,15 +413,16 @@ object ScalaReflection extends ScalaReflection { // For non-primitives, we can just extract the object from the Option and then recurse. case other => - val className: String = optType.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(optType) val classObj = Utils.classForName(className) val optionObjectType = ObjectType(classObj) + val newPath = s"""- option value class: "$className"""" +: walkedTypePath val unwrapped = UnwrapOption(optionObjectType, inputObject) expressions.If( IsNull(unwrapped), - expressions.Literal.create(null, schemaFor(optType).dataType), - extractorFor(unwrapped, optType)) + expressions.Literal.create(null, silentSchemaFor(optType).dataType), + extractorFor(unwrapped, optType, newPath)) } case t if t <:< localTypeOf[Product] => @@ -412,7 +448,10 @@ object ScalaReflection extends ScalaReflection { val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + val clsName = getClassNameFromType(fieldType) + val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil }) case t if t <:< localTypeOf[Array[_]] => @@ -500,23 +539,11 @@ object ScalaReflection extends ScalaReflection { Invoke(inputObject, "booleanValue", BooleanType) case other => - throw new UnsupportedOperationException(s"Extractor for type $other is not supported") + throw new UnsupportedOperationException( + s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } } } - - private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = schemaFor(elementType) - if (isNativeType(catalystType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - MapObjects(extractorFor(_, elementType), input, externalDataType) - } - } } /** @@ -561,7 +588,7 @@ trait ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { - val className: String = tpe.erasure.typeSymbol.asClass.fullName + val className = getClassNameFromType(tpe) tpe match { case t if Utils.classIsLoadable(className) && Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => @@ -637,6 +664,23 @@ trait ScalaReflection { } } + /** + * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. + * + * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return + * `NullType` silently instead. + */ + private def silentSchemaFor(tpe: `Type`): Schema = try { + schemaFor(tpe) + } catch { + case _: UnsupportedOperationException => Schema(NullType, nullable = true) + } + + /** Returns the full class name for a type. */ + private def getClassNameFromType(tpe: `Type`): String = { + tpe.erasure.typeSymbol.asClass.fullName + } + /** * Returns classes of input parameters of scala function object. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala index 0b2a10bb04c1..8c766ef82992 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderErrorMessageSuite.scala @@ -17,9 +17,22 @@ package org.apache.spark.sql.catalyst.encoders +import scala.reflect.ClassTag + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders +class NonEncodable(i: Int) + +case class ComplexNonEncodable1(name1: NonEncodable) + +case class ComplexNonEncodable2(name2: ComplexNonEncodable1) + +case class ComplexNonEncodable3(name3: Option[NonEncodable]) + +case class ComplexNonEncodable4(name4: Array[NonEncodable]) + +case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]]) class EncoderErrorMessageSuite extends SparkFunSuite { @@ -37,4 +50,53 @@ class EncoderErrorMessageSuite extends SparkFunSuite { intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] } intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] } } + + test("nice error message for missing encoder") { + val errorMsg1 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage + assert(errorMsg1.contains( + s"""root class: "${clsName[ComplexNonEncodable1]}"""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg2 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage + assert(errorMsg2.contains( + s"""root class: "${clsName[ComplexNonEncodable2]}"""")) + assert(errorMsg2.contains( + s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")""")) + assert(errorMsg1.contains( + s"""field (class: "${clsName[NonEncodable]}", name: "name1")""")) + + val errorMsg3 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage + assert(errorMsg3.contains( + s"""root class: "${clsName[ComplexNonEncodable3]}"""")) + assert(errorMsg3.contains( + s"""field (class: "scala.Option", name: "name3")""")) + assert(errorMsg3.contains( + s"""option value class: "${clsName[NonEncodable]}"""")) + + val errorMsg4 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage + assert(errorMsg4.contains( + s"""root class: "${clsName[ComplexNonEncodable4]}"""")) + assert(errorMsg4.contains( + s"""field (class: "scala.Array", name: "name4")""")) + assert(errorMsg4.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + + val errorMsg5 = + intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage + assert(errorMsg5.contains( + s"""root class: "${clsName[ComplexNonEncodable5]}"""")) + assert(errorMsg5.contains( + s"""field (class: "scala.Option", name: "name5")""")) + assert(errorMsg5.contains( + s"""option value class: "scala.Array"""")) + assert(errorMsg5.contains( + s"""array element class: "${clsName[NonEncodable]}"""")) + } + + private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName } From 652def318e47890bd0a0977dc982cc07f99fb06a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 13:17:35 -0800 Subject: [PATCH 825/896] [SPARK-11650] Reduce RPC timeouts to speed up slow AkkaUtilsSuite test This patch 
reduces some RPC timeouts in order to speed up the slow "AkkaUtilsSuite.remote fetch ssl on - untrusted server", which used to take two minutes to run. Author: Josh Rosen Closes #9869 from JoshRosen/SPARK-11650. --- core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 61601016e005..0af4b6098bb0 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -340,10 +340,11 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) val slaveConf = sparkSSLConfig() + .set("spark.rpc.askTimeout", "5s") + .set("spark.rpc.lookupTimeout", "5s") val securityManagerBad = new SecurityManager(slaveConf) val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) try { slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) fail("should receive either ActorNotFound or TimeoutException") From 9ed4ad4265cf9d3135307eb62dae6de0b220fc21 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 20 Nov 2015 14:19:34 -0800 Subject: [PATCH 826/896] [SPARK-11724][SQL] Change casting between int and timestamp to consistently treat int in seconds. Hive has since changed this behavior as well. https://issues.apache.org/jira/browse/HIVE-3454 Author: Nong Li Author: Nong Li Author: Yin Huai Closes #9685 from nongli/spark-11724. --- .../spark/sql/catalyst/expressions/Cast.scala | 6 ++-- .../sql/catalyst/expressions/CastSuite.scala | 16 +++++---- .../apache/spark/sql/DateFunctionsSuite.scala | 3 ++ ...esting-0-237a6af90a857da1efcbe98f6bbbf9d6} | 2 +- ... 
cast #3-0-76ee270337f664b36cacfc6528ac109 | 1 - ...cast #5-0-dbd7bcd167d322d6617b884c02c7f247 | 1 - ...cast #7-0-1d70654217035f8ce5f64344f4c5a80f | 1 - .../sql/hive/execution/HiveQuerySuite.scala | 34 +++++++++++++------ 8 files changed, 39 insertions(+), 25 deletions(-) rename sql/hive/src/test/resources/golden/{constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 => constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6} (52%) delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5564e242b047..533d17ea5c17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -204,8 +204,8 @@ case class Cast(child: Expression, dataType: DataType) if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong } - // converting milliseconds to us - private[this] def longToTimestamp(t: Long): Long = t * 1000L + // converting seconds to us + private[this] def longToTimestamp(t: Long): Long = t * 1000000L // converting us to seconds private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 1000000L).toLong // converting us to seconds in double @@ -647,7 +647,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def decimalToTimestampCode(d: String): String = s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" - private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L" + private[this] def longToTimeStampCode(l: String): String = s"$l * 1000000L" private[this] def timestampToIntegerCode(ts: String): String = s"java.lang.Math.floor((double) $ts / 1000000L)" private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index f4db4da7646f..ab77a764483e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -258,8 +258,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from int 2") { checkEvaluation(cast(1, LongType), 1.toLong) - checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) - checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) + checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong) + checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong) checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) @@ -348,14 +348,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation( cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), 
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), null) checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), ByteType), TimestampType), LongType), StringType), ShortType), - 0.toShort) + 5.toShort) checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) @@ -479,10 +479,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(ts, LongType), 15.toLong) checkEvaluation(cast(ts, FloatType), 15.003f) checkEvaluation(cast(ts, DoubleType), 15.003) - checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation(cast(cast(tss, IntegerType), TimestampType), - DateTimeUtils.fromJavaTimestamp(ts)) - checkEvaluation(cast(cast(tss, LongType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + DateTimeUtils.fromJavaTimestamp(ts) * 1000) + checkEvaluation(cast(cast(tss, LongType), TimestampType), + DateTimeUtils.fromJavaTimestamp(ts) * 1000) checkEvaluation( cast(cast(millis.toFloat / 1000, TimestampType), FloatType), millis.toFloat / 1000) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 241cbd011507..a61c3aa48a73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -448,6 +448,9 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + + val now = sql("select unix_timestamp()").collect().head.getLong(0) + checkAnswer(sql(s"select cast ($now as timestamp)"), Row(new java.util.Date(now * 1000))) } test("to_unix_timestamp") { diff --git a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 similarity index 52% rename from sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 rename to sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 index 7c41615f8c18..a01c2622c68e 100644 --- a/sql/hive/src/test/resources/golden/constant null testing-0-9a02bc7de09bcabcbd4c91f54a814c20 +++ b/sql/hive/src/test/resources/golden/constant null testing-0-237a6af90a857da1efcbe98f6bbbf9d6 @@ -1 +1 @@ -1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL 1969-12-31 16:00:00.001 NULL 1 NULL +1 NULL 1 NULL 1.0 NULL true NULL 1 NULL 1.0 NULL 1 NULL 1 NULL 1 NULL 1970-01-01 NULL NULL 1 NULL diff --git a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 b/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 deleted file mode 100644 index d00491fd7e5b..000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 +++ /dev/null @@ -1 +0,0 @@ -1 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 deleted file mode 100644 index 84a31a5a6970..000000000000 --- 
a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 +++ /dev/null @@ -1 +0,0 @@ --0.001 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f b/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f deleted file mode 100644 index 3fbedf693b51..000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f +++ /dev/null @@ -1 +0,0 @@ --2 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index f0a7a6cc7a1e..8a5acaf3e10b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.sql.Timestamp import java.util.{Locale, TimeZone} import scala.util.Try @@ -248,12 +249,17 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |IF(TRUE, CAST(NULL AS BINARY), CAST("1" AS BINARY)) AS COL18, |IF(FALSE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL19, |IF(TRUE, CAST(NULL AS DATE), CAST("1970-01-01" AS DATE)) AS COL20, - |IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, - |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL22, - |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23, - |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL24 + |IF(TRUE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL21, + |IF(FALSE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL22, + |IF(TRUE, CAST(NULL AS DECIMAL), CAST(1 AS DECIMAL)) AS COL23 |FROM src LIMIT 1""".stripMargin) + test("constant null testing timestamp") { + val r1 = sql("SELECT IF(FALSE, CAST(NULL AS TIMESTAMP), CAST(1 AS TIMESTAMP)) AS COL20") + .collect().head + assert(new Timestamp(1000) == r1.getTimestamp(0)) + } + createQueryTest("constant array", """ |SELECT sort_array( @@ -603,26 +609,32 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // Jdk version leads to different query output for double, so not use createQueryTest here test("timestamp cast #1") { val res = sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head - assert(0.001 == res.getDouble(0)) + assert(1 == res.getDouble(0)) } createQueryTest("timestamp cast #2", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #3", - "SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #3") { + val res = sql("SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head + assert(1200 == res.getInt(0)) + } createQueryTest("timestamp cast #4", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #5", - "SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + test("timestamp cast #5") { + val res = sql("SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head + assert(-1 == res.get(0)) + } createQueryTest("timestamp cast #6", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("timestamp cast #7", - "SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + test("timestamp cast #7") { + val res = sql("SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 
1").collect().head + assert(-1200 == res.getInt(0)) + } createQueryTest("timestamp cast #8", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") From be7a2cfd978143f6f265eca63e9e24f755bc9f22 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 20 Nov 2015 14:23:01 -0800 Subject: [PATCH 827/896] [SPARK-11870][STREAMING][PYSPARK] Rethrow the exceptions in TransformFunction and TransformFunctionSerializer TransformFunction and TransformFunctionSerializer don't rethrow the exception, so when any exception happens, it just return None. This will cause some weird NPE and confuse people. Author: Shixiong Zhu Closes #9847 from zsxwing/pyspark-streaming-exception. --- python/pyspark/streaming/tests.py | 16 ++++++++++++++++ python/pyspark/streaming/util.py | 3 +++ 2 files changed, 19 insertions(+) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 3403f6d20d78..a0e0267cafa5 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -403,6 +403,22 @@ def func(dstream): expected = [[('k', v)] for v in expected] self._test_func(input, func, expected) + def test_failed_func(self): + input = [self.sc.parallelize([d], 1) for d in range(4)] + input_stream = self.ssc.queueStream(input) + + def failed_func(i): + raise ValueError("failed") + + input_stream.map(failed_func).pprint() + self.ssc.start() + try: + self.ssc.awaitTerminationOrTimeout(10) + except: + return + + self.fail("a failed func should throw an error") + class StreamingListenerTests(PySparkStreamingTestCase): diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index b20613b1283b..767c732eb90b 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -64,6 +64,7 @@ def call(self, milliseconds, jrdds): return r._jrdd except Exception: traceback.print_exc() + raise def __repr__(self): return "TransformFunction(%s)" % self.func @@ -95,6 +96,7 @@ def dumps(self, id): return bytearray(self.serializer.dumps((func.func, func.deserializers))) except Exception: traceback.print_exc() + raise def loads(self, data): try: @@ -102,6 +104,7 @@ def loads(self, data): return TransformFunction(self.ctx, f, *deserializers) except Exception: traceback.print_exc() + raise def __repr__(self): return "TransformFunctionSerializer(%s)" % self.serializer From 89fd9bd06160fa89dedbf685bfe159ffe4a06ec6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 20 Nov 2015 14:31:26 -0800 Subject: [PATCH 828/896] [SPARK-11887] Close PersistenceEngine at the end of PersistenceEngineSuite tests In PersistenceEngineSuite, we do not call `close()` on the PersistenceEngine at the end of the test. For the ZooKeeperPersistenceEngine, this causes us to leak a ZooKeeper client, causing the logs of unrelated tests to be periodically spammed with connection error messages from that client: ``` 15/11/20 05:13:35.789 pool-1-thread-1-ScalaTest-running-PersistenceEngineSuite-SendThread(localhost:15741) INFO ClientCnxn: Opening socket connection to server localhost/127.0.0.1:15741. 
Will not attempt to authenticate using SASL (unknown error) 15/11/20 05:13:35.790 pool-1-thread-1-ScalaTest-running-PersistenceEngineSuite-SendThread(localhost:15741) WARN ClientCnxn: Session 0x15124ff48dd0000 for server null, unexpected error, closing socket connection and attempting reconnect java.net.ConnectException: Connection refused at sun.nio.ch.SocketChannelImpl.checkConnect(Native Method) at sun.nio.ch.SocketChannelImpl.finishConnect(SocketChannelImpl.java:739) at org.apache.zookeeper.ClientCnxnSocketNIO.doTransport(ClientCnxnSocketNIO.java:350) at org.apache.zookeeper.ClientCnxn$SendThread.run(ClientCnxn.java:1068) ``` This patch fixes this by using a `finally` block. Author: Josh Rosen Closes #9864 from JoshRosen/close-zookeeper-client-in-tests. --- .../master/PersistenceEngineSuite.scala | 100 +++++++++--------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index 34775577de8a..7a4472867568 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -63,56 +63,60 @@ class PersistenceEngineSuite extends SparkFunSuite { conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { val serializer = new JavaSerializer(conf) val persistenceEngine = persistenceEngineCreator(serializer) - persistenceEngine.persist("test_1", "test_1_value") - assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.persist("test_2", "test_2_value") - assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) - persistenceEngine.unpersist("test_1") - assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) - persistenceEngine.unpersist("test_2") - assert(persistenceEngine.read[String]("test_").isEmpty) - - // Test deserializing objects that contain RpcEndpointRef - val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) try { - // Create a real endpoint so that we can test RpcEndpointRef deserialization - val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { - override val rpcEnv: RpcEnv = testRpcEnv - }) - - val workerToPersist = new WorkerInfo( - id = "test_worker", - host = "127.0.0.1", - port = 10000, - cores = 0, - memory = 0, - endpoint = workerEndpoint, - webUiPort = 0, - publicAddress = "" - ) - - persistenceEngine.addWorker(workerToPersist) - - val (storedApps, storedDrivers, storedWorkers) = - persistenceEngine.readPersistedData(testRpcEnv) - - assert(storedApps.isEmpty) - assert(storedDrivers.isEmpty) - - // Check deserializing WorkerInfo - assert(storedWorkers.size == 1) - val recoveryWorkerInfo = storedWorkers.head - assert(workerToPersist.id === recoveryWorkerInfo.id) - assert(workerToPersist.host === recoveryWorkerInfo.host) - assert(workerToPersist.port === recoveryWorkerInfo.port) - assert(workerToPersist.cores === recoveryWorkerInfo.cores) - assert(workerToPersist.memory === recoveryWorkerInfo.memory) - assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) - assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) - assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === 
persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val testRpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = testRpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = testRpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiPort = 0, + publicAddress = "" + ) + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = + persistenceEngine.readPersistedData(testRpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) + assert(workerToPersist.publicAddress === recoveryWorkerInfo.publicAddress) + } finally { + testRpcEnv.shutdown() + testRpcEnv.awaitTermination() + } } finally { - testRpcEnv.shutdown() - testRpcEnv.awaitTermination() + persistenceEngine.close() } } From 03ba56d78f50747710d01c27d409ba2be42ae557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= Date: Fri, 20 Nov 2015 14:45:40 -0800 Subject: [PATCH 829/896] [SPARK-11716][SQL] UDFRegistration just drops the input type when re-creating the UserDefinedFunction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://issues.apache.org/jira/browse/SPARK-11716 This is one is #9739 and a regression test. When commit it, please make sure the author is jbonofre. You can find the original PR at https://github.com/apache/spark/pull/9739 closes #9739 Author: Jean-Baptiste Onofré Author: Yin Huai Closes #9868 from yhuai/SPARK-11716. 
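
To illustrate the behavior this patch addresses, here is a minimal sketch (not part of the patch) of how a registered UDF is typically used from the DataFrame API. It assumes a SQLContext named sqlContext and a DataFrame df with an integer column a and an integer column b; the name pairFunc is made up for the example. Because the returned UserDefinedFunction previously dropped its input types, the analyzer could not insert the implicit cast from the integer column b to the UDF's String parameter, and queries like the one below failed to analyze.

import org.apache.spark.sql.functions.col

// Registering a Scala function returns a UserDefinedFunction that can be applied
// directly to Columns. Before this fix, the returned object was re-created without
// its inputTypes, so the analyzer had no type information for the arguments.
val myUDF = sqlContext.udf.register("pairFunc", (n: Int, s: String) => (n, s.toInt))

// b is an integer column; with inputTypes preserved it can be cast to String to
// match the UDF's second parameter. Without them, analysis of this query fails.
val result = df.select(myUDF(col("a"), col("b")).as("t"))
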
--- .../apache/spark/sql/UDFRegistration.scala | 48 +++++++++---------- .../scala/org/apache/spark/sql/UDFSuite.scala | 15 ++++++ 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index fc4d0938c533..051694c0d43a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -88,7 +88,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try($inputTypes).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) }""") } @@ -120,7 +120,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -133,7 +133,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -146,7 +146,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -159,7 +159,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -172,7 +172,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -185,7 +185,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - 
UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -198,7 +198,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -211,7 +211,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -224,7 +224,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -237,7 +237,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -250,7 +250,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** 
@@ -263,7 +263,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -276,7 +276,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -289,7 +289,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -302,7 +302,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -315,7 +315,7 
@@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -328,7 +328,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -341,7 +341,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -354,7 +354,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -367,7 +367,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -380,7 +380,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -393,7 +393,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } /** @@ -406,7 +406,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + UserDefinedFunction(func, dataType, inputTypes) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 9837fa6bdb35..fd736718af12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -232,4 +232,19 @@ class UDFSuite extends QueryTest with SharedSQLContext { | (SELECT complexDataFunc(m, a, b) AS t FROM complexData) tmp """.stripMargin).toDF(), complexData.select("m", "a", "b")) } + + test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") { + val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) + + // Without the fix, this will fail because we fail to cast data type of b to string + // because myUDF does not know its input data type. With the fix, this query should not + // fail. 
+ checkAnswer( + testData2.select(myUDF($"a", $"b").as("t")), + testData2.selectExpr("struct(a, b)")) + + checkAnswer( + sql("SELECT tmp.t.* FROM (SELECT testDataFunc(a, b) AS t from testData2) tmp").toDF(), + testData2) + } } From a6239d587c638691f52eca3eee905c53fbf35a12 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Fri, 20 Nov 2015 15:10:55 -0800 Subject: [PATCH 830/896] [SPARK-11756][SPARKR] Fix use of aliases - SparkR can not output help information for SparkR:::summary correctly Fix use of aliases and changes uses of rdname and seealso `aliases` is the hint for `?` - it should not be linked to some other name - those should be seealso https://cran.r-project.org/web/packages/roxygen2/vignettes/rd.html Clean up usage on family, as multiple use of family with the same rdname is causing duplicated See Also html blocks (like http://spark.apache.org/docs/latest/api/R/count.html) Also changing some rdname for dplyr-like variant for better R user visibility in R doc, eg. rbind, summary, mutate, summarize shivaram yanboliang Author: felixcheung Closes #9750 from felixcheung/rdocaliases. --- R/pkg/R/DataFrame.R | 96 ++++++++++++--------------------------------- R/pkg/R/broadcast.R | 1 - R/pkg/R/generics.R | 12 +++--- R/pkg/R/group.R | 12 +++--- 4 files changed, 37 insertions(+), 84 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 06b0108b1389..8a13e7a36766 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -254,7 +254,6 @@ setMethod("dtypes", #' @family DataFrame functions #' @rdname columns #' @name columns -#' @aliases names #' @export #' @examples #'\dontrun{ @@ -272,7 +271,6 @@ setMethod("columns", }) }) -#' @family DataFrame functions #' @rdname columns #' @name names setMethod("names", @@ -281,7 +279,6 @@ setMethod("names", columns(x) }) -#' @family DataFrame functions #' @rdname columns #' @name names<- setMethod("names<-", @@ -533,14 +530,8 @@ setMethod("distinct", dataFrame(sdf) }) -#' @title Distinct rows in a DataFrame -# -#' @description Returns a new DataFrame containing distinct rows in this DataFrame -#' -#' @family DataFrame functions -#' @rdname unique +#' @rdname distinct #' @name unique -#' @aliases distinct setMethod("unique", signature(x = "DataFrame"), function(x) { @@ -557,7 +548,7 @@ setMethod("unique", #' #' @family DataFrame functions #' @rdname sample -#' @aliases sample_frac +#' @name sample #' @export #' @examples #'\dontrun{ @@ -579,7 +570,6 @@ setMethod("sample", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname sample #' @name sample_frac setMethod("sample_frac", @@ -589,16 +579,15 @@ setMethod("sample_frac", sample(x, withReplacement, fraction) }) -#' Count +#' nrow #' #' Returns the number of rows in a DataFrame #' #' @param x A SparkSQL DataFrame #' #' @family DataFrame functions -#' @rdname count +#' @rdname nrow #' @name count -#' @aliases nrow #' @export #' @examples #'\dontrun{ @@ -614,14 +603,8 @@ setMethod("count", callJMethod(x@sdf, "count") }) -#' @title Number of rows for a DataFrame -#' @description Returns number of rows in a DataFrames -#' #' @name nrow -#' -#' @family DataFrame functions #' @rdname nrow -#' @aliases count setMethod("nrow", signature(x = "DataFrame"), function(x) { @@ -870,7 +853,6 @@ setMethod("toRDD", #' @param x a DataFrame #' @return a GroupedData #' @seealso GroupedData -#' @aliases group_by #' @family DataFrame functions #' @rdname groupBy #' @name groupBy @@ -896,7 +878,6 @@ setMethod("groupBy", groupedData(sgd) }) -#' @family DataFrame functions #' @rdname groupBy #' @name 
group_by setMethod("group_by", @@ -913,7 +894,6 @@ setMethod("group_by", #' @family DataFrame functions #' @rdname agg #' @name agg -#' @aliases summarize #' @export setMethod("agg", signature(x = "DataFrame"), @@ -921,7 +901,6 @@ setMethod("agg", agg(groupBy(x), ...) }) -#' @family DataFrame functions #' @rdname agg #' @name summarize setMethod("summarize", @@ -1092,7 +1071,6 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' @family DataFrame functions #' @rdname subset #' @name subset -#' @aliases [ #' @family subsetting functions #' @examples #' \dontrun{ @@ -1216,7 +1194,7 @@ setMethod("selectExpr", #' @family DataFrame functions #' @rdname withColumn #' @name withColumn -#' @aliases mutate transform +#' @seealso \link{rename} \link{mutate} #' @export #' @examples #'\dontrun{ @@ -1231,7 +1209,6 @@ setMethod("withColumn", function(x, colName, col) { select(x, x$"*", alias(col, colName)) }) - #' Mutate #' #' Return a new DataFrame with the specified columns added. @@ -1240,9 +1217,9 @@ setMethod("withColumn", #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. #' @family DataFrame functions -#' @rdname withColumn +#' @rdname mutate #' @name mutate -#' @aliases withColumn transform +#' @seealso \link{rename} \link{withColumn} #' @export #' @examples #'\dontrun{ @@ -1273,17 +1250,15 @@ setMethod("mutate", }) #' @export -#' @family DataFrame functions -#' @rdname withColumn +#' @rdname mutate #' @name transform -#' @aliases withColumn mutate setMethod("transform", signature(`_data` = "DataFrame"), function(`_data`, ...) { mutate(`_data`, ...) }) -#' WithColumnRenamed +#' rename #' #' Rename an existing column in a DataFrame. #' @@ -1292,8 +1267,9 @@ setMethod("transform", #' @param newCol The new column name. #' @return A DataFrame with the column name changed. #' @family DataFrame functions -#' @rdname withColumnRenamed +#' @rdname rename #' @name withColumnRenamed +#' @seealso \link{mutate} #' @export #' @examples #'\dontrun{ @@ -1316,17 +1292,9 @@ setMethod("withColumnRenamed", select(x, cols) }) -#' Rename -#' -#' Rename an existing column in a DataFrame. -#' -#' @param x A DataFrame -#' @param newCol A named pair of the form new_column_name = existing_column -#' @return A DataFrame with the column name changed. 
-#' @family DataFrame functions -#' @rdname withColumnRenamed +#' @param newColPair A named pair of the form new_column_name = existing_column +#' @rdname rename #' @name rename -#' @aliases withColumnRenamed #' @export #' @examples #'\dontrun{ @@ -1371,7 +1339,6 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @family DataFrame functions #' @rdname arrange #' @name arrange -#' @aliases orderby #' @export #' @examples #'\dontrun{ @@ -1395,8 +1362,8 @@ setMethod("arrange", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname arrange +#' @name arrange #' @export setMethod("arrange", signature(x = "DataFrame", col = "character"), @@ -1427,9 +1394,9 @@ setMethod("arrange", do.call("arrange", c(x, jcols)) }) -#' @family DataFrame functions #' @rdname arrange -#' @name orderby +#' @name orderBy +#' @export setMethod("orderBy", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col) { @@ -1492,6 +1459,7 @@ setMethod("where", #' @family DataFrame functions #' @rdname join #' @name join +#' @seealso \link{merge} #' @export #' @examples #'\dontrun{ @@ -1528,9 +1496,7 @@ setMethod("join", dataFrame(sdf) }) -#' #' @name merge -#' @aliases join #' @title Merges two data frames #' @param x the first data frame to be joined #' @param y the second data frame to be joined @@ -1550,6 +1516,7 @@ setMethod("join", #' outer join will be returned. #' @family DataFrame functions #' @rdname merge +#' @seealso \link{join} #' @export #' @examples #'\dontrun{ @@ -1671,7 +1638,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { cols } -#' UnionAll +#' rbind #' #' Return a new DataFrame containing the union of rows in this DataFrame #' and another DataFrame. This is equivalent to `UNION ALL` in SQL. @@ -1681,7 +1648,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the union. #' @family DataFrame functions -#' @rdname unionAll +#' @rdname rbind #' @name unionAll #' @export #' @examples @@ -1700,13 +1667,11 @@ setMethod("unionAll", }) #' @title Union two or more DataFrames -#' #' @description Returns a new DataFrame containing rows of all parameters. #' -#' @family DataFrame functions #' @rdname rbind #' @name rbind -#' @aliases unionAll +#' @export setMethod("rbind", signature(... = "DataFrame"), function(x, ..., deparse.level = 1) { @@ -1795,7 +1760,6 @@ setMethod("except", #' @family DataFrame functions #' @rdname write.df #' @name write.df -#' @aliases saveDF #' @export #' @examples #'\dontrun{ @@ -1828,7 +1792,6 @@ setMethod("write.df", callJMethod(df@sdf, "save", source, jmode, options) }) -#' @family DataFrame functions #' @rdname write.df #' @name saveDF #' @export @@ -1891,7 +1854,7 @@ setMethod("saveAsTable", callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) }) -#' describe +#' summary #' #' Computes statistics for numeric columns. #' If no columns are given, this function computes statistics for all numerical columns. @@ -1901,9 +1864,8 @@ setMethod("saveAsTable", #' @param ... 
Additional expressions #' @return A DataFrame #' @family DataFrame functions -#' @rdname describe +#' @rdname summary #' @name describe -#' @aliases summary #' @export #' @examples #'\dontrun{ @@ -1923,8 +1885,7 @@ setMethod("describe", dataFrame(sdf) }) -#' @family DataFrame functions -#' @rdname describe +#' @rdname summary #' @name describe setMethod("describe", signature(x = "DataFrame"), @@ -1934,11 +1895,6 @@ setMethod("describe", dataFrame(sdf) }) -#' @title Summary -#' -#' @description Computes statistics for numeric columns of the DataFrame -#' -#' @family DataFrame functions #' @rdname summary #' @name summary setMethod("summary", @@ -1966,7 +1922,6 @@ setMethod("summary", #' @family DataFrame functions #' @rdname nafunctions #' @name dropna -#' @aliases na.omit #' @export #' @examples #'\dontrun{ @@ -1993,7 +1948,6 @@ setMethod("dropna", dataFrame(sdf) }) -#' @family DataFrame functions #' @rdname nafunctions #' @name na.omit #' @export @@ -2019,9 +1973,7 @@ setMethod("na.omit", #' type are ignored. For example, if value is a character, and #' subset contains a non-character column, then the non-character #' column is simply ignored. -#' @return A DataFrame #' -#' @family DataFrame functions #' @rdname nafunctions #' @name fillna #' @export diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 2403925b267c..38f0eed95e06 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -51,7 +51,6 @@ Broadcast <- function(id, value, jBroadcastRef, objName) { # # @param bcast The broadcast variable to get # @rdname broadcast -# @aliases value,Broadcast-method setMethod("value", signature(bcast = "Broadcast"), function(bcast) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 71004a05ba61..1b3f10ea0464 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -397,7 +397,7 @@ setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) #' @export setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") }) -#' @rdname describe +#' @rdname summary #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) @@ -459,11 +459,11 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) -#' rdname merge +#' @rdname merge #' @export setGeneric("merge") -#' @rdname withColumn +#' @rdname mutate #' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) @@ -475,7 +475,7 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) #' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("rename", function(x, ...) 
{ standardGeneric("rename") }) @@ -553,7 +553,7 @@ setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) -#' @rdname unionAll +#' @rdname rbind #' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) @@ -565,7 +565,7 @@ setGeneric("where", function(x, condition) { standardGeneric("where") }) #' @export setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") }) -#' @rdname withColumnRenamed +#' @rdname rename #' @export setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index e5f702faee65..23b49aebda05 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -68,7 +68,7 @@ setMethod("count", dataFrame(callJMethod(x@sgd, "count")) }) -#' Agg +#' summarize #' #' Aggregates on the entire DataFrame without groups. #' The resulting DataFrame will also contain the grouping columns. @@ -78,12 +78,14 @@ setMethod("count", #' #' @param x a GroupedData #' @return a DataFrame -#' @rdname agg +#' @rdname summarize +#' @name agg #' @family agg_funcs #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' -#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df3 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum +#' df4 <- summarize(df, ageSum = max(df$age)) #' } setMethod("agg", signature(x = "GroupedData"), @@ -110,8 +112,8 @@ setMethod("agg", dataFrame(sdf) }) -#' @rdname agg -#' @aliases agg +#' @rdname summarize +#' @name summarize setMethod("summarize", signature(x = "GroupedData"), function(x, ...) { From 4b84c72dfbb9ddb415fee35f69305b5d7b280891 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:17:17 -0800 Subject: [PATCH 831/896] [SPARK-11636][SQL] Support classes defined in the REPL with Encoders #theScaryParts (i.e. changes to the repl, executor classloaders and codegen)... Author: Michael Armbrust Author: Yin Huai Closes #9825 from marmbrus/dataset-replClasses2. --- .../org/apache/spark/repl/SparkIMain.scala | 14 +++++++---- .../org/apache/spark/repl/ReplSuite.scala | 24 +++++++++++++++++++ .../spark/repl/ExecutorClassLoader.scala | 8 ++++++- .../expressions/codegen/CodeGenerator.scala | 4 ++-- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 4ee605fd7f11..829b12269fd2 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1221,10 +1221,16 @@ import org.apache.spark.annotation.DeveloperApi ) } - val preamble = """ - |class %s extends Serializable { - | %s%s%s - """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute)) + val preamble = s""" + |class ${lineRep.readName} extends Serializable { + | ${envLines.map(" " + _ + ";\n").mkString} + | $importsPreamble + | + | // If we need to construct any objects defined in the REPL on an executor we will need + | // to pass the outer scope to the appropriate encoder. 
+ | org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this) + | ${indentCode(toCompute)} + """.stripMargin val postamble = importsTrailer + "\n}" + "\n" + "object " + lineRep.readName + " {\n" + " val INSTANCE = new " + lineRep.readName + "();\n" + diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 5674dcd669be..081aa03002cc 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -262,6 +262,9 @@ class ReplSuite extends SparkFunSuite { |import sqlContext.implicits._ |case class TestCaseClass(value: Int) |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect() + | + |// Test Dataset Serialization in the REPL + |Seq(TestCaseClass(1)).toDS().collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -278,6 +281,27 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("java.lang.ClassNotFoundException", output) } + test("Datasets and encoders") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |val simpleSum = new Aggregator[Int, Int, Int] with Serializable { + | def zero: Int = 0 // The initial value. + | def reduce(b: Int, a: Int) = b + a // Add an element to the running total + | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values. + | def finish(b: Int) = b // Return the final result. + |}.toColumn + | + |val ds = Seq(1, 2, 3, 4).toDS() + |ds.select(simpleSum).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 3d2d235a00c9..a976e96809cb 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -65,7 +65,13 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader case e: ClassNotFoundException => { val classOption = findClassLocally(name) classOption match { - case None => throw new ClassNotFoundException(name, e) + case None => + // If this class has a cause, it will break the internal assumption of Janino + // (the compiler used for Spark SQL code-gen). + // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see + // its behavior will be changed if there is a cause and the compilation + // of generated class will fail. 
+ throw new ClassNotFoundException(name) case Some(a) => a } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1b7260cdfe51..2f3d6aeb86c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ - +import org.apache.spark.util.Utils /** * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. @@ -536,7 +536,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() - evaluator.setParentClassLoader(getClass.getClassLoader) + evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( From ed47b1e660b830e2d4fac8d6df93f634b260393c Mon Sep 17 00:00:00 2001 From: Vikas Nelamangala Date: Fri, 20 Nov 2015 15:18:41 -0800 Subject: [PATCH 832/896] [SPARK-11549][DOCS] Replace example code in mllib-evaluation-metrics.md using include_example Author: Vikas Nelamangala Closes #9689 from vikasnp/master. --- docs/mllib-evaluation-metrics.md | 940 +----------------- ...avaBinaryClassificationMetricsExample.java | 113 +++ ...ultiLabelClassificationMetricsExample.java | 80 ++ ...ulticlassClassificationMetricsExample.java | 97 ++ .../mllib/JavaRankingMetricsExample.java | 176 ++++ .../mllib/JavaRegressionMetricsExample.java | 91 ++ .../binary_classification_metrics_example.py | 55 + .../mllib/multi_class_metrics_example.py | 69 ++ .../mllib/multi_label_metrics_example.py | 61 ++ .../python/mllib/ranking_metrics_example.py | 55 + .../mllib/regression_metrics_example.py | 59 ++ .../BinaryClassificationMetricsExample.scala | 103 ++ .../mllib/MultiLabelMetricsExample.scala | 69 ++ .../mllib/MulticlassMetricsExample.scala | 99 ++ .../mllib/RankingMetricsExample.scala | 110 ++ .../mllib/RegressionMetricsExample.scala | 67 ++ 16 files changed, 1319 insertions(+), 925 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java create mode 100644 examples/src/main/python/mllib/binary_classification_metrics_example.py create mode 100644 examples/src/main/python/mllib/multi_class_metrics_example.py create mode 100644 examples/src/main/python/mllib/multi_label_metrics_example.py create mode 100644 examples/src/main/python/mllib/ranking_metrics_example.py create mode 100644 examples/src/main/python/mllib/regression_metrics_example.py 
create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index f73eff637dc3..6924037b941f 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -104,214 +104,21 @@ data, and evaluate the performance of the algorithm by several binary evaluation
    Refer to the [`LogisticRegressionWithLBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS) and [`BinaryClassificationMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.BinaryClassificationMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training) - -// Clear the prediction threshold so the model will return probabilities -model.clearThreshold - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new BinaryClassificationMetrics(predictionAndLabels) - -// Precision by threshold -val precision = metrics.precisionByThreshold -precision.foreach { case (t, p) => - println(s"Threshold: $t, Precision: $p") -} - -// Recall by threshold -val recall = metrics.recallByThreshold -recall.foreach { case (t, r) => - println(s"Threshold: $t, Recall: $r") -} - -// Precision-Recall Curve -val PRC = metrics.pr - -// F-measure -val f1Score = metrics.fMeasureByThreshold -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 1") -} - -val beta = 0.5 -val fScore = metrics.fMeasureByThreshold(beta) -f1Score.foreach { case (t, f) => - println(s"Threshold: $t, F-score: $f, Beta = 0.5") -} - -// AUPRC -val auPRC = metrics.areaUnderPR -println("Area under precision-recall curve = " + auPRC) - -// Compute thresholds used in ROC and PR curves -val thresholds = precision.map(_._1) - -// ROC Curve -val roc = metrics.roc - -// AUROC -val auROC = metrics.areaUnderROC -println("Area under ROC = " + auROC) - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala %}
    Refer to the [`LogisticRegressionModel` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionModel.html) and [`LogisticRegressionWithLBFGS` Java docs](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class BinaryClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_binary_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(2) - .run(training.rdd()); - - // Clear the prediction threshold so the model will return probabilities - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. 
- BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); - - // Precision by threshold - JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); - System.out.println("Precision by threshold: " + precision.toArray()); - - // Recall by threshold - JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); - System.out.println("Recall by threshold: " + recall.toArray()); - - // F Score by threshold - JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); - System.out.println("F1 Score by threshold: " + f1Score.toArray()); - - JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); - System.out.println("F2 Score by threshold: " + f2Score.toArray()); - - // Precision-recall curve - JavaRDD> prc = metrics.pr().toJavaRDD(); - System.out.println("Precision-recall curve: " + prc.toArray()); - - // Thresholds - JavaRDD thresholds = precision.map( - new Function, Double>() { - public Double call (Tuple2 t) { - return new Double(t._1().toString()); - } - } - ); - - // ROC Curve - JavaRDD> roc = metrics.roc().toJavaRDD(); - System.out.println("ROC curve: " + roc.toArray()); - - // AUPRC - System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); - - // AUROC - System.out.println("Area under ROC = " + metrics.areaUnderROC()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java %}
    Refer to the [`BinaryClassificationMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.BinaryClassificationMetrics) and [`LogisticRegressionWithLBFGS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.classification.LogisticRegressionWithLBFGS) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.evaluation import BinaryClassificationMetrics -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils - -# Several of the methods available in scala are currently missing from pyspark - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = BinaryClassificationMetrics(predictionAndLabels) - -# Area under precision-recall curve -print("Area under PR = %s" % metrics.areaUnderPR) - -# Area under ROC curve -print("Area under ROC = %s" % metrics.areaUnderROC) - -{% endhighlight %} - +{% include_example python/mllib/binary_classification_metrics_example.py %}
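The `{% include_example %}` tag used throughout this patch is provided by a Jekyll plugin in the Spark docs build; the plugin itself is not part of this diff. As a rough sketch of the idea (helper names invented, path reused from this patch), the generated docs pull in only the region of a source file between its `$example on$` and `$example off$` markers, so the published snippets stay in sync with compilable example code:

{% highlight scala %}
import scala.io.Source

// Illustrative sketch only -- not the real Jekyll plugin. It keeps the lines between
// the "$example on$" and "$example off$" markers and drops everything else
// (license header, doc-irrelevant imports). The real tag also supports several
// marked regions per file; this sketch keeps just the first one.
object IncludeExampleSketch {
  def extractExample(path: String): String = {
    val lines = Source.fromFile(path).getLines().toSeq
    lines
      .dropWhile(l => !l.contains("$example on$"))  // skip everything before the opening marker
      .drop(1)                                      // drop the marker line itself
      .takeWhile(l => !l.contains("$example off$")) // stop at the closing marker
      .mkString("\n")
  }

  def main(args: Array[String]): Unit = {
    // Path taken from this patch; run from the Spark source root.
    println(extractExample("examples/src/main/python/mllib/binary_classification_metrics_example.py"))
  }
}
{% endhighlight %}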
    @@ -433,204 +240,21 @@ the data, and evaluate the performance of the algorithm by several multiclass cl
    Refer to the [`MulticlassMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MulticlassMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils - -// Load training data in LIBSVM format -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -// Split data into training (60%) and test (40%) -val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) -training.cache() - -// Run training algorithm to build the model -val model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training) - -// Compute raw scores on the test set -val predictionAndLabels = test.map { case LabeledPoint(label, features) => - val prediction = model.predict(features) - (prediction, label) -} - -// Instantiate metrics object -val metrics = new MulticlassMetrics(predictionAndLabels) - -// Confusion matrix -println("Confusion matrix:") -println(metrics.confusionMatrix) - -// Overall Statistics -val precision = metrics.precision -val recall = metrics.recall // same as true positive rate -val f1Score = metrics.fMeasure -println("Summary Statistics") -println(s"Precision = $precision") -println(s"Recall = $recall") -println(s"F1 Score = $f1Score") - -// Precision by label -val labels = metrics.labels -labels.foreach { l => - println(s"Precision($l) = " + metrics.precision(l)) -} - -// Recall by label -labels.foreach { l => - println(s"Recall($l) = " + metrics.recall(l)) -} - -// False positive rate by label -labels.foreach { l => - println(s"FPR($l) = " + metrics.falsePositiveRate(l)) -} - -// F-measure by label -labels.foreach { l => - println(s"F1-Score($l) = " + metrics.fMeasure(l)) -} - -// Weighted stats -println(s"Weighted precision: ${metrics.weightedPrecision}") -println(s"Weighted recall: ${metrics.weightedRecall}") -println(s"Weighted F1 score: ${metrics.weightedFMeasure}") -println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala %}
    Refer to the [`MulticlassMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MulticlassMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; -import org.apache.spark.mllib.evaluation.MulticlassMetrics; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class MulticlassClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_multiclass_classification_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L); - JavaRDD training = splits[0].cache(); - JavaRDD test = splits[1]; - - // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() - .setNumClasses(3) - .run(training.rdd()); - - // Compute raw scores on the test set. - JavaRDD> predictionAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2(prediction, p.label()); - } - } - ); - - // Get evaluation metrics. - MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); - - // Confusion matrix - Matrix confusion = metrics.confusionMatrix(); - System.out.println("Confusion matrix: \n" + confusion); - - // Overall statistics - System.out.println("Precision = " + metrics.precision()); - System.out.println("Recall = " + metrics.recall()); - System.out.println("F1 Score = " + metrics.fMeasure()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length; i++) { - System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i])); - } - - //Weighted stats - System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); - System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); - System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); - System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); - - // Save and load model - model.save(sc, "myModelPath"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath"); - } -} - -{% endhighlight %} + {% include_example java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java %}
    Refer to the [`MulticlassMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MulticlassMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS -from pyspark.mllib.util import MLUtils -from pyspark.mllib.evaluation import MulticlassMetrics - -# Load training data in LIBSVM format -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") - -# Split data into training (60%) and test (40%) -training, test = data.randomSplit([0.6, 0.4], seed = 11L) -training.cache() - -# Run training algorithm to build the model -model = LogisticRegressionWithLBFGS.train(training, numClasses=3) - -# Compute raw scores on the test set -predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) - -# Instantiate metrics object -metrics = MulticlassMetrics(predictionAndLabels) - -# Overall statistics -precision = metrics.precision() -recall = metrics.recall() -f1Score = metrics.fMeasure() -print("Summary Stats") -print("Precision = %s" % precision) -print("Recall = %s" % recall) -print("F1 Score = %s" % f1Score) - -# Statistics by class -labels = data.map(lambda lp: lp.label).distinct().collect() -for label in sorted(labels): - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) - -# Weighted stats -print("Weighted recall = %s" % metrics.weightedRecall) -print("Weighted precision = %s" % metrics.weightedPrecision) -print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) -print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) -print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) -{% endhighlight %} +{% include_example python/mllib/multi_class_metrics_example.py %}
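For readers checking the relocated multiclass examples, here is a small standalone sketch (not part of this patch; the (prediction, label) pairs are made up) of how the per-class precision, recall, and F1 reported by `MulticlassMetrics` fall out of plain confusion counts:

{% highlight scala %}
// Standalone sketch: per-class precision/recall/F1 from (prediction, label) pairs.
object MulticlassByHand {
  def main(args: Array[String]): Unit = {
    val predictionAndLabels = Seq(
      (0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (1.0, 0.0), (2.0, 1.0), (0.0, 0.0), (2.0, 2.0))
    val labels = predictionAndLabels.flatMap { case (p, l) => Seq(p, l) }.distinct.sorted

    labels.foreach { c =>
      val tp = predictionAndLabels.count { case (p, l) => p == c && l == c }
      val fp = predictionAndLabels.count { case (p, l) => p == c && l != c }
      val fn = predictionAndLabels.count { case (p, l) => p != c && l == c }
      val precision = if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
      val recall    = if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)
      val f1 = if (precision + recall == 0) 0.0 else 2 * precision * recall / (precision + recall)
      println(f"Class $c: precision = $precision%.3f, recall = $recall%.3f, F1 = $f1%.3f")
    }
  }
}
{% endhighlight %}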
    @@ -766,154 +390,21 @@ True classes:
    Refer to the [`MultilabelMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.MultilabelMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.MultilabelMetrics -import org.apache.spark.rdd.RDD; - -val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( - Seq((Array(0.0, 1.0), Array(0.0, 2.0)), - (Array(0.0, 2.0), Array(0.0, 1.0)), - (Array(), Array(0.0)), - (Array(2.0), Array(2.0)), - (Array(2.0, 0.0), Array(2.0, 0.0)), - (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), - (Array(1.0), Array(1.0, 2.0))), 2) - -// Instantiate metrics object -val metrics = new MultilabelMetrics(scoreAndLabels) - -// Summary stats -println(s"Recall = ${metrics.recall}") -println(s"Precision = ${metrics.precision}") -println(s"F1 measure = ${metrics.f1Measure}") -println(s"Accuracy = ${metrics.accuracy}") - -// Individual label stats -metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) -metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) -metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) - -// Micro stats -println(s"Micro recall = ${metrics.microRecall}") -println(s"Micro precision = ${metrics.microPrecision}") -println(s"Micro F1 measure = ${metrics.microF1Measure}") - -// Hamming loss -println(s"Hamming loss = ${metrics.hammingLoss}") - -// Subset accuracy -println(s"Subset accuracy = ${metrics.subsetAccuracy}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala %}
    Refer to the [`MultilabelMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/MultilabelMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.evaluation.MultilabelMetrics; -import org.apache.spark.SparkConf; -import java.util.Arrays; -import java.util.List; - -public class MultilabelClassification { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - - List> data = Arrays.asList( - new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), - new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{}, new double[]{0.0}), - new Tuple2(new double[]{2.0}, new double[]{2.0}), - new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), - new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) - ); - JavaRDD> scoreAndLabels = sc.parallelize(data); - - // Instantiate metrics object - MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); - - // Summary stats - System.out.format("Recall = %f\n", metrics.recall()); - System.out.format("Precision = %f\n", metrics.precision()); - System.out.format("F1 measure = %f\n", metrics.f1Measure()); - System.out.format("Accuracy = %f\n", metrics.accuracy()); - - // Stats by labels - for (int i = 0; i < metrics.labels().length - 1; i++) { - System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i])); - System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i])); - System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i])); - } - - // Micro stats - System.out.format("Micro recall = %f\n", metrics.microRecall()); - System.out.format("Micro precision = %f\n", metrics.microPrecision()); - System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); - - // Hamming loss - System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); - - // Subset accuracy - System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); - - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java %}
    Refer to the [`MultilabelMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.MultilabelMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.evaluation import MultilabelMetrics - -scoreAndLabels = sc.parallelize([ - ([0.0, 1.0], [0.0, 2.0]), - ([0.0, 2.0], [0.0, 1.0]), - ([], [0.0]), - ([2.0], [2.0]), - ([2.0, 0.0], [2.0, 0.0]), - ([0.0, 1.0, 2.0], [0.0, 1.0]), - ([1.0], [1.0, 2.0])]) - -# Instantiate metrics object -metrics = MultilabelMetrics(scoreAndLabels) - -# Summary stats -print("Recall = %s" % metrics.recall()) -print("Precision = %s" % metrics.precision()) -print("F1 measure = %s" % metrics.f1Measure()) -print("Accuracy = %s" % metrics.accuracy) - -# Individual label stats -labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() -for label in labels: - print("Class %s precision = %s" % (label, metrics.precision(label))) - print("Class %s recall = %s" % (label, metrics.recall(label))) - print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) - -# Micro stats -print("Micro precision = %s" % metrics.microPrecision) -print("Micro recall = %s" % metrics.microRecall) -print("Micro F1 measure = %s" % metrics.microF1Measure) - -# Hamming loss -print("Hamming loss = %s" % metrics.hammingLoss) - -# Subset accuracy -print("Subset accuracy = %s" % metrics.subsetAccuracy) - -{% endhighlight %} +{% include_example python/mllib/multi_label_metrics_example.py %}
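The Hamming loss and subset accuracy printed by the multilabel examples can be verified by hand. A standalone sketch (not part of this patch) over the same style of (predicted, true) label sets as the example data:

{% highlight scala %}
// Standalone sketch: Hamming loss and subset accuracy for multilabel predictions.
object MultilabelByHand {
  def main(args: Array[String]): Unit = {
    val scoreAndLabels: Seq[(Set[Double], Set[Double])] = Seq(
      (Set(0.0, 1.0), Set(0.0, 2.0)),
      (Set(0.0, 2.0), Set(0.0, 1.0)),
      (Set.empty[Double], Set(0.0)),
      (Set(2.0), Set(2.0)),
      (Set(2.0, 0.0), Set(2.0, 0.0)),
      (Set(0.0, 1.0, 2.0), Set(0.0, 1.0)),
      (Set(1.0), Set(1.0, 2.0)))

    val numLabels = scoreAndLabels.flatMap { case (p, t) => p ++ t }.toSet.size

    // Hamming loss: size of the symmetric difference between predicted and true
    // label sets, averaged over documents and labels.
    val hammingLoss = scoreAndLabels.map { case (p, t) =>
      ((p diff t) ++ (t diff p)).size.toDouble
    }.sum / (scoreAndLabels.size * numLabels)

    // Subset accuracy: fraction of documents whose predicted set matches exactly.
    val subsetAccuracy =
      scoreAndLabels.count { case (p, t) => p == t }.toDouble / scoreAndLabels.size

    println(s"Hamming loss = $hammingLoss")
    println(s"Subset accuracy = $subsetAccuracy")
  }
}
{% endhighlight %}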
    @@ -1027,280 +518,21 @@ expanded world of non-positive weights are "the same as never having interacted
    Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RankingMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} -import org.apache.spark.mllib.recommendation.{ALS, Rating} - -// Read in the ratings data -val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => - val fields = line.split("::") - Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) -}.cache() - -// Map ratings to 1 or 0, 1 indicating a movie that should be recommended -val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() - -// Summarize ratings -val numRatings = ratings.count() -val numUsers = ratings.map(_.user).distinct().count() -val numMovies = ratings.map(_.product).distinct().count() -println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") - -// Build the model -val numIterations = 10 -val rank = 10 -val lambda = 0.01 -val model = ALS.train(ratings, rank, numIterations, lambda) - -// Define a function to scale ratings from 0 to 1 -def scaledRating(r: Rating): Rating = { - val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) - Rating(r.user, r.product, scaledRating) -} - -// Get sorted top ten predictions for each user and then scale from [0, 1] -val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => - (user, recs.map(scaledRating)) -} - -// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document -// Compare with top ten most relevant documents -val userMovies = binarizedRatings.groupBy(_.user) -val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => - (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) -} - -// Instantiate metrics object -val metrics = new RankingMetrics(relevantDocuments) - -// Precision at K -Array(1, 3, 5).foreach{ k => - println(s"Precision at $k = ${metrics.precisionAt(k)}") -} - -// Mean average precision -println(s"Mean average precision = ${metrics.meanAveragePrecision}") - -// Normalized discounted cumulative gain -Array(1, 3, 5).foreach{ k => - println(s"NDCG at $k = ${metrics.ndcgAt(k)}") -} - -// Get predictions for each data point -val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) -val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) -val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => - (predicted, actual) -} - -// Get the RMSE using regression metrics -val regressionMetrics = new RegressionMetrics(predictionsAndLabels) -println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${regressionMetrics.r2}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala %}
    Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) and [`RankingMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RankingMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.rdd.RDD; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function; -import java.util.*; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.mllib.evaluation.RankingMetrics; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.Rating; - -// Read in the ratings data -public class Ranking { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Ranking Metrics"); - JavaSparkContext sc = new JavaSparkContext(conf); - String path = "data/mllib/sample_movielens_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String line) { - String[] parts = line.split("::"); - return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5); - } - } - ); - ratings.cache(); - - // Train an ALS model - final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); - - // Get top 10 recommendations for every user and scale ratings from 0 to 1 - JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); - JavaRDD> userRecsScaled = userRecs.map( - new Function, Tuple2>() { - public Tuple2 call(Tuple2 t) { - Rating[] scaledRatings = new Rating[t._2().length]; - for (int i = 0; i < scaledRatings.length; i++) { - double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); - scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); - } - return new Tuple2(t._1(), scaledRatings); - } - } - ); - JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); - - // Map ratings to 1 or 0, 1 indicating a movie that should be recommended - JavaRDD binarizedRatings = ratings.map( - new Function() { - public Rating call(Rating r) { - double binaryRating; - if (r.rating() > 0.0) { - binaryRating = 1.0; - } - else { - binaryRating = 0.0; - } - return new Rating(r.user(), r.product(), binaryRating); - } - } - ); - - // Group ratings by common user - JavaPairRDD> userMovies = binarizedRatings.groupBy( - new Function() { - public Object call(Rating r) { - return r.user(); - } - } - ); - - // Get true relevant documents from all user ratings - JavaPairRDD> userMoviesList = userMovies.mapValues( - new Function, List>() { - public List call(Iterable docs) { - List products = new ArrayList(); - for (Rating r : docs) { - if (r.rating() > 0.0) { - products.add(r.product()); - } - } - return products; - } - } - ); - - // Extract the product id from each recommendation - JavaPairRDD> userRecommendedList = userRecommended.mapValues( - new Function>() { - public List call(Rating[] docs) { - List products = new ArrayList(); - for (Rating r : docs) { - products.add(r.product()); - } - return products; - } - } - ); - JavaRDD, List>> relevantDocs = userMoviesList.join(userRecommendedList).values(); - - // Instantiate the metrics object - RankingMetrics metrics = RankingMetrics.of(relevantDocs); - - // Precision and NDCG at k - Integer[] kVector = {1, 3, 5}; - for (Integer k : kVector) { - 
System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); - System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); - } - - // Mean average precision - System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision()); - - // Evaluate the model using numerical ratings and regression metrics - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); - JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Object>>() { - public Tuple2, Object> call(Rating r){ - return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - - // Create regression metrics object - RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); - - // Root mean squared error - System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R-squared = %f\n", regressionMetrics.r2()); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java %}
    Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) and [`RankingMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.recommendation import ALS, Rating -from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics - -# Read in the ratings data -lines = sc.textFile("data/mllib/sample_movielens_data.txt") - -def parseLine(line): - fields = line.split("::") - return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5) -ratings = lines.map(lambda r: parseLine(r)) - -# Train a model on to predict user-product ratings -model = ALS.train(ratings, 10, 10, 0.01) - -# Get predicted ratings on all existing user-product pairs -testData = ratings.map(lambda p: (p.user, p.product)) -predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating)) - -ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating)) -scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) - -# Instantiate regression metrics to compare predicted and actual ratings -metrics = RegressionMetrics(scoreAndLabels) - -# Root mean sqaured error -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -{% endhighlight %} +{% include_example python/mllib/ranking_metrics_example.py %}
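As background for the ranking examples, a standalone sketch (not part of this patch; document ids made up) of precision@k and NDCG@k under one common binary-relevance definition. Edge-case handling may differ slightly from `RankingMetrics`:

{% highlight scala %}
// Standalone sketch: precision@k and NDCG@k for one ranked recommendation list.
object RankingByHand {
  def main(args: Array[String]): Unit = {
    val recommended = Seq(12, 4, 7, 30, 2)  // ranked predictions, best first
    val relevant = Set(4, 7, 99)            // ground-truth relevant documents

    def precisionAt(k: Int): Double =
      recommended.take(k).count(relevant.contains).toDouble / k

    def ndcgAt(k: Int): Double = {
      // Binary gain with logarithmic position discount: 1 / log2(position + 2).
      def discount(position: Int): Double = math.log(2) / math.log(position + 2)
      val dcg = recommended.take(k).zipWithIndex.collect {
        case (doc, i) if relevant.contains(doc) => discount(i)
      }.sum
      val idcg = (0 until math.min(relevant.size, k)).map(discount).sum
      if (idcg == 0.0) 0.0 else dcg / idcg
    }

    Seq(1, 3, 5).foreach { k =>
      println(f"Precision at $k = ${precisionAt(k)}%.3f, NDCG at $k = ${ndcgAt(k)}%.3f")
    }
  }
}
{% endhighlight %}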
    @@ -1350,163 +582,21 @@ and evaluate the performance of the algorithm by several regression metrics.
    Refer to the [`RegressionMetrics` Scala docs](api/scala/index.html#org.apache.spark.mllib.evaluation.RegressionMetrics) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.regression.LinearRegressionModel -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.util.MLUtils - -// Load the data -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() - -// Build the model -val numIterations = 100 -val model = LinearRegressionWithSGD.train(data, numIterations) - -// Get predictions -val valuesAndPreds = data.map{ point => - val prediction = model.predict(point.features) - (prediction, point.label) -} - -// Instantiate metrics object -val metrics = new RegressionMetrics(valuesAndPreds) - -// Squared error -println(s"MSE = ${metrics.meanSquaredError}") -println(s"RMSE = ${metrics.rootMeanSquaredError}") - -// R-squared -println(s"R-squared = ${metrics.r2}") - -// Mean absolute error -println(s"MAE = ${metrics.meanAbsoluteError}") - -// Explained variance -println(s"Explained variance = ${metrics.explainedVariance}") - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala %}
    Refer to the [`RegressionMetrics` Java docs](api/java/org/apache/spark/mllib/evaluation/RegressionMetrics.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.regression.LinearRegressionModel; -import org.apache.spark.mllib.regression.LinearRegressionWithSGD; -import org.apache.spark.mllib.evaluation.RegressionMetrics; -import org.apache.spark.SparkConf; - -public class LinearRegression { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/sample_linear_regression_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public LabeledPoint call(String line) { - String[] parts = line.split(" "); - double[] v = new double[parts.length - 1]; - for (int i = 1; i < parts.length - 1; i++) - v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); - return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); - } - } - ); - parsedData.cache(); - - // Building the model - int numIterations = 100; - final LinearRegressionModel model = - LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); - - // Evaluate model on training examples and compute training error - JavaRDD> valuesAndPreds = parsedData.map( - new Function>() { - public Tuple2 call(LabeledPoint point) { - double prediction = model.predict(point.features()); - return new Tuple2(prediction, point.label()); - } - } - ); - - // Instantiate metrics object - RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); - - // Squared error - System.out.format("MSE = %f\n", metrics.meanSquaredError()); - System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); - - // R-squared - System.out.format("R Squared = %f\n", metrics.r2()); - - // Mean absolute error - System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); - - // Explained variance - System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath"); - } -} - -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java %}
    Refer to the [`RegressionMetrics` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RegressionMetrics) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD -from pyspark.mllib.evaluation import RegressionMetrics -from pyspark.mllib.linalg import DenseVector - -# Load and parse the data -def parsePoint(line): - values = line.split() - return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) - -data = sc.textFile("data/mllib/sample_linear_regression_data.txt") -parsedData = data.map(parsePoint) - -# Build the model -model = LinearRegressionWithSGD.train(parsedData) - -# Get predictions -valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) - -# Instantiate metrics object -metrics = RegressionMetrics(valuesAndPreds) - -# Squared Error -print("MSE = %s" % metrics.meanSquaredError) -print("RMSE = %s" % metrics.rootMeanSquaredError) - -# R-squared -print("R-squared = %s" % metrics.r2) - -# Mean absolute error -print("MAE = %s" % metrics.meanAbsoluteError) - -# Explained variance -print("Explained variance = %s" % metrics.explainedVariance) - -{% endhighlight %} +{% include_example python/mllib/regression_metrics_example.py %}
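The regression metrics referenced above reduce to a few lines of arithmetic. A standalone sketch (not part of this patch; values made up) of the quantities `RegressionMetrics` reports:

{% highlight scala %}
// Standalone sketch: MSE, RMSE, MAE and R-squared from (prediction, label) pairs.
object RegressionByHand {
  def main(args: Array[String]): Unit = {
    val predictionsAndLabels = Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0))
    val n = predictionsAndLabels.size

    val mse  = predictionsAndLabels.map { case (p, y) => math.pow(p - y, 2) }.sum / n
    val rmse = math.sqrt(mse)
    val mae  = predictionsAndLabels.map { case (p, y) => math.abs(p - y) }.sum / n

    // R-squared: 1 - residual sum of squares / total sum of squares around the label mean.
    val meanLabel = predictionsAndLabels.map(_._2).sum / n
    val ssRes = predictionsAndLabels.map { case (p, y) => math.pow(y - p, 2) }.sum
    val ssTot = predictionsAndLabels.map { case (_, y) => math.pow(y - meanLabel, 2) }.sum
    val r2 = 1.0 - ssRes / ssTot

    println(f"MSE = $mse%.4f, RMSE = $rmse%.4f, MAE = $mae%.4f, R-squared = $r2%.4f")
  }
}
{% endhighlight %}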
    diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java new file mode 100644 index 000000000000..980a9108af53 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class JavaBinaryClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_binary_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = + data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training.rdd()); + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. 
+ BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); + + // Precision by threshold + JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); + System.out.println("Precision by threshold: " + precision.toArray()); + + // Recall by threshold + JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); + System.out.println("Recall by threshold: " + recall.toArray()); + + // F Score by threshold + JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); + System.out.println("F1 Score by threshold: " + f1Score.toArray()); + + JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); + System.out.println("F2 Score by threshold: " + f2Score.toArray()); + + // Precision-recall curve + JavaRDD> prc = metrics.pr().toJavaRDD(); + System.out.println("Precision-recall curve: " + prc.toArray()); + + // Thresholds + JavaRDD thresholds = precision.map( + new Function, Double>() { + public Double call(Tuple2 t) { + return new Double(t._1().toString()); + } + } + ); + + // ROC Curve + JavaRDD> roc = metrics.roc().toJavaRDD(); + System.out.println("ROC curve: " + roc.toArray()); + + // AUPRC + System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); + + // AUROC + System.out.println("Area under ROC = " + metrics.areaUnderROC()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java new file mode 100644 index 000000000000..b54e1ea3f2bc --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.mllib.evaluation.MultilabelMetrics; +import org.apache.spark.rdd.RDD; +import org.apache.spark.SparkConf; +// $example off$ +import org.apache.spark.SparkContext; + +public class JavaMultiLabelClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + List> data = Arrays.asList( + new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), + new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{}, new double[]{0.0}), + new Tuple2(new double[]{2.0}, new double[]{2.0}), + new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), + new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) + ); + JavaRDD> scoreAndLabels = sc.parallelize(data); + + // Instantiate metrics object + MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); + + // Summary stats + System.out.format("Recall = %f\n", metrics.recall()); + System.out.format("Precision = %f\n", metrics.precision()); + System.out.format("F1 measure = %f\n", metrics.f1Measure()); + System.out.format("Accuracy = %f\n", metrics.accuracy()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length - 1; i++) { + System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision + (metrics.labels()[i])); + System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics + .labels()[i])); + System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure + (metrics.labels()[i])); + } + + // Micro stats + System.out.format("Micro recall = %f\n", metrics.microRecall()); + System.out.format("Micro precision = %f\n", metrics.microPrecision()); + System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure()); + + // Hamming loss + System.out.format("Hamming loss = %f\n", metrics.hammingLoss()); + + // Subset accuracy + System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java new file mode 100644 index 000000000000..21f628fb51b6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.mllib.linalg.Matrix; +// $example off$ +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; + +public class JavaMulticlassClassificationMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example"); + SparkContext sc = new SparkContext(conf); + // $example on$ + String path = "data/mllib/sample_multiclass_classification_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L); + JavaRDD training = splits[0].cache(); + JavaRDD test = splits[1]; + + // Run training algorithm to build the model. + final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training.rdd()); + + // Compute raw scores on the test set. + JavaRDD> predictionAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double prediction = model.predict(p.features()); + return new Tuple2(prediction, p.label()); + } + } + ); + + // Get evaluation metrics. 
+ MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); + + // Confusion matrix + Matrix confusion = metrics.confusionMatrix(); + System.out.println("Confusion matrix: \n" + confusion); + + // Overall statistics + System.out.println("Precision = " + metrics.precision()); + System.out.println("Recall = " + metrics.recall()); + System.out.println("F1 Score = " + metrics.fMeasure()); + + // Stats by labels + for (int i = 0; i < metrics.labels().length; i++) { + System.out.format("Class %f precision = %f\n", metrics.labels()[i],metrics.precision + (metrics.labels()[i])); + System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics + .labels()[i])); + System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure + (metrics.labels()[i])); + } + + //Weighted stats + System.out.format("Weighted precision = %f\n", metrics.weightedPrecision()); + System.out.format("Weighted recall = %f\n", metrics.weightedRecall()); + System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure()); + System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate()); + + // Save and load model + model.save(sc, "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java new file mode 100644 index 000000000000..7c4c97e74681 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.*; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.mllib.evaluation.RankingMetrics; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +// $example off$ +import org.apache.spark.SparkConf; + +public class JavaRankingMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Ranking Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + String path = "data/mllib/sample_movielens_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String line) { + String[] parts = line.split("::"); + return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double + .parseDouble(parts[2]) - 2.5); + } + } + ); + ratings.cache(); + + // Train an ALS model + final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01); + + // Get top 10 recommendations for every user and scale ratings from 0 to 1 + JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); + JavaRDD> userRecsScaled = userRecs.map( + new Function, Tuple2>() { + public Tuple2 call(Tuple2 t) { + Rating[] scaledRatings = new Rating[t._2().length]; + for (int i = 0; i < scaledRatings.length; i++) { + double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); + scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); + } + return new Tuple2(t._1(), scaledRatings); + } + } + ); + JavaPairRDD userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled); + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + JavaRDD binarizedRatings = ratings.map( + new Function() { + public Rating call(Rating r) { + double binaryRating; + if (r.rating() > 0.0) { + binaryRating = 1.0; + } else { + binaryRating = 0.0; + } + return new Rating(r.user(), r.product(), binaryRating); + } + } + ); + + // Group ratings by common user + JavaPairRDD> userMovies = binarizedRatings.groupBy( + new Function() { + public Object call(Rating r) { + return r.user(); + } + } + ); + + // Get true relevant documents from all user ratings + JavaPairRDD> userMoviesList = userMovies.mapValues( + new Function, List>() { + public List call(Iterable docs) { + List products = new ArrayList(); + for (Rating r : docs) { + if (r.rating() > 0.0) { + products.add(r.product()); + } + } + return products; + } + } + ); + + // Extract the product id from each recommendation + JavaPairRDD> userRecommendedList = userRecommended.mapValues( + new Function>() { + public List call(Rating[] docs) { + List products = new ArrayList(); + for (Rating r : docs) { + products.add(r.product()); + } + return products; + } + } + ); + JavaRDD, List>> relevantDocs = userMoviesList.join + (userRecommendedList).values(); + + // Instantiate the metrics object + RankingMetrics metrics = RankingMetrics.of(relevantDocs); + + // Precision and NDCG at k + Integer[] kVector = {1, 3, 5}; + for (Integer k : kVector) { + System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k)); + System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k)); + } + + // Mean average precision + System.out.format("Mean average 
precision = %f\n", metrics.meanAveragePrecision()); + + // Evaluate the model using numerical ratings and regression metrics + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Object>>() { + public Tuple2, Object> call(Rating r) { + return new Tuple2, Object>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + + // Create regression metrics object + RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd()); + + // Root mean squared error + System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R-squared = %f\n", regressionMetrics.r2()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java new file mode 100644 index 000000000000..d2efc6bf9777 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRegressionMetricsExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.regression.LinearRegressionModel; +import org.apache.spark.mllib.regression.LinearRegressionWithSGD; +import org.apache.spark.mllib.evaluation.RegressionMetrics; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRegressionMetricsExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("Java Regression Metrics Example"); + JavaSparkContext sc = new JavaSparkContext(conf); + // $example on$ + // Load and parse the data + String path = "data/mllib/sample_linear_regression_data.txt"; + JavaRDD data = sc.textFile(path); + JavaRDD parsedData = data.map( + new Function() { + public LabeledPoint call(String line) { + String[] parts = line.split(" "); + double[] v = new double[parts.length - 1]; + for (int i = 1; i < parts.length - 1; i++) + v[i - 1] = Double.parseDouble(parts[i].split(":")[1]); + return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v)); + } + } + ); + parsedData.cache(); + + // Building the model + int numIterations = 100; + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), + numIterations); + + // Evaluate model on training examples and compute training error + JavaRDD> valuesAndPreds = parsedData.map( + new Function>() { + public Tuple2 call(LabeledPoint point) { + double prediction = model.predict(point.features()); + return new Tuple2(prediction, point.label()); + } + } + ); + + // Instantiate metrics object + RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd()); + + // Squared error + System.out.format("MSE = %f\n", metrics.meanSquaredError()); + System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError()); + + // R-squared + System.out.format("R Squared = %f\n", metrics.r2()); + + // Mean absolute error + System.out.format("MAE = %f\n", metrics.meanAbsoluteError()); + + // Explained variance + System.out.format("Explained Variance = %f\n", metrics.explainedVariance()); + + // Save and load model + model.save(sc.sc(), "target/tmp/LogisticRegressionModel"); + LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), + "target/tmp/LogisticRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/python/mllib/binary_classification_metrics_example.py b/examples/src/main/python/mllib/binary_classification_metrics_example.py new file mode 100644 index 000000000000..437acb998acc --- /dev/null +++ b/examples/src/main/python/mllib/binary_classification_metrics_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Binary Classification Metrics Example. +""" +from __future__ import print_function +import sys +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="BinaryClassificationMetricsExample") + sqlContext = SQLContext(sc) + # $example on$ + # Several of the methods available in scala are currently missing from pyspark + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = BinaryClassificationMetrics(predictionAndLabels) + + # Area under precision-recall curve + print("Area under PR = %s" % metrics.areaUnderPR) + + # Area under ROC curve + print("Area under ROC = %s" % metrics.areaUnderROC) + # $example off$ diff --git a/examples/src/main/python/mllib/multi_class_metrics_example.py b/examples/src/main/python/mllib/multi_class_metrics_example.py new file mode 100644 index 000000000000..cd56b3c97c77 --- /dev/null +++ b/examples/src/main/python/mllib/multi_class_metrics_example.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# $example on$ +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiClassMetricsExample") + + # Several of the methods available in scala are currently missing from pyspark + # $example on$ + # Load training data in LIBSVM format + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + # Split data into training (60%) and test (40%) + training, test = data.randomSplit([0.6, 0.4], seed=11L) + training.cache() + + # Run training algorithm to build the model + model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + + # Compute raw scores on the test set + predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + + # Instantiate metrics object + metrics = MulticlassMetrics(predictionAndLabels) + + # Overall statistics + precision = metrics.precision() + recall = metrics.recall() + f1Score = metrics.fMeasure() + print("Summary Stats") + print("Precision = %s" % precision) + print("Recall = %s" % recall) + print("F1 Score = %s" % f1Score) + + # Statistics by class + labels = data.map(lambda lp: lp.label).distinct().collect() + for label in sorted(labels): + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) + + # Weighted stats + print("Weighted recall = %s" % metrics.weightedRecall) + print("Weighted precision = %s" % metrics.weightedPrecision) + print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) + print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) + print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) + # $example off$ diff --git a/examples/src/main/python/mllib/multi_label_metrics_example.py b/examples/src/main/python/mllib/multi_label_metrics_example.py new file mode 100644 index 000000000000..960ade659737 --- /dev/null +++ b/examples/src/main/python/mllib/multi_label_metrics_example.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# $example on$ +from pyspark.mllib.evaluation import MultilabelMetrics +# $example off$ +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="MultiLabelMetricsExample") + # $example on$ + scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + + # Instantiate metrics object + metrics = MultilabelMetrics(scoreAndLabels) + + # Summary stats + print("Recall = %s" % metrics.recall()) + print("Precision = %s" % metrics.precision()) + print("F1 measure = %s" % metrics.f1Measure()) + print("Accuracy = %s" % metrics.accuracy) + + # Individual label stats + labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() + for label in labels: + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) + + # Micro stats + print("Micro precision = %s" % metrics.microPrecision) + print("Micro recall = %s" % metrics.microRecall) + print("Micro F1 measure = %s" % metrics.microF1Measure) + + # Hamming loss + print("Hamming loss = %s" % metrics.hammingLoss) + + # Subset accuracy + print("Subset accuracy = %s" % metrics.subsetAccuracy) + # $example off$ diff --git a/examples/src/main/python/mllib/ranking_metrics_example.py b/examples/src/main/python/mllib/ranking_metrics_example.py new file mode 100644 index 000000000000..327791966c90 --- /dev/null +++ b/examples/src/main/python/mllib/ranking_metrics_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+# $example on$
+from pyspark.mllib.recommendation import ALS, Rating
+from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics
+# $example off$
+from pyspark import SparkContext
+
+if __name__ == "__main__":
+    sc = SparkContext(appName="Ranking Metrics Example")
+
+    # Several of the methods available in scala are currently missing from pyspark
+    # $example on$
+    # Read in the ratings data
+    lines = sc.textFile("data/mllib/sample_movielens_data.txt")
+
+    def parseLine(line):
+        fields = line.split("::")
+        return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5)
+    ratings = lines.map(lambda r: parseLine(r))
+
+    # Train a model to predict user-product ratings
+    model = ALS.train(ratings, 10, 10, 0.01)
+
+    # Get predicted ratings on all existing user-product pairs
+    testData = ratings.map(lambda p: (p.user, p.product))
+    predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating))
+
+    ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating))
+    scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1])
+
+    # Instantiate regression metrics to compare predicted and actual ratings
+    metrics = RegressionMetrics(scoreAndLabels)
+
+    # Root mean squared error
+    print("RMSE = %s" % metrics.rootMeanSquaredError)
+
+    # R-squared
+    print("R-squared = %s" % metrics.r2)
+    # $example off$
diff --git a/examples/src/main/python/mllib/regression_metrics_example.py b/examples/src/main/python/mllib/regression_metrics_example.py
new file mode 100644
index 000000000000..a3a83aafd7a1
--- /dev/null
+++ b/examples/src/main/python/mllib/regression_metrics_example.py
@@ -0,0 +1,59 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# +# $example on$ +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector +# $example off$ + +from pyspark import SparkContext + +if __name__ == "__main__": + sc = SparkContext(appName="Regression Metrics Example") + + # $example on$ + # Load and parse the data + def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), + DenseVector([float(x.split(':')[1]) for x in values[1:]])) + + data = sc.textFile("data/mllib/sample_linear_regression_data.txt") + parsedData = data.map(parsePoint) + + # Build the model + model = LinearRegressionWithSGD.train(parsedData) + + # Get predictions + valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + + # Instantiate metrics object + metrics = RegressionMetrics(valuesAndPreds) + + # Squared Error + print("MSE = %s" % metrics.meanSquaredError) + print("RMSE = %s" % metrics.rootMeanSquaredError) + + # R-squared + print("R-squared = %s" % metrics.r2) + + # Mean absolute error + print("MAE = %s" % metrics.meanAbsoluteError) + + # Explained variance + print("Explained variance = %s" % metrics.explainedVariance) + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala new file mode 100644 index 000000000000..13a37827ab93 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object BinaryClassificationMetricsExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("BinaryClassificationMetricsExample") + val sc = new SparkContext(conf) + // $example on$ + // Load training data in LIBSVM format + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + + // Split data into training (60%) and test (40%) + val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) + training.cache() + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(2) + .run(training) + + // Clear the prediction threshold so the model will return probabilities + model.clearThreshold + + // Compute raw scores on the test set + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Instantiate metrics object + val metrics = new BinaryClassificationMetrics(predictionAndLabels) + + // Precision by threshold + val precision = metrics.precisionByThreshold + precision.foreach { case (t, p) => + println(s"Threshold: $t, Precision: $p") + } + + // Recall by threshold + val recall = metrics.recallByThreshold + recall.foreach { case (t, r) => + println(s"Threshold: $t, Recall: $r") + } + + // Precision-Recall Curve + val PRC = metrics.pr + + // F-measure + val f1Score = metrics.fMeasureByThreshold + f1Score.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 1") + } + + val beta = 0.5 + val fScore = metrics.fMeasureByThreshold(beta) + f1Score.foreach { case (t, f) => + println(s"Threshold: $t, F-score: $f, Beta = 0.5") + } + + // AUPRC + val auPRC = metrics.areaUnderPR + println("Area under precision-recall curve = " + auPRC) + + // Compute thresholds used in ROC and PR curves + val thresholds = precision.map(_._1) + + // ROC Curve + val roc = metrics.roc + + // AUROC + val auROC = metrics.areaUnderROC + println("Area under ROC = " + auROC) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala new file mode 100644 index 000000000000..4503c15360ad --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object MultiLabelMetricsExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MultiLabelMetricsExample") + val sc = new SparkContext(conf) + // $example on$ + val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array.empty[Double], Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + + // Instantiate metrics object + val metrics = new MultilabelMetrics(scoreAndLabels) + + // Summary stats + println(s"Recall = ${metrics.recall}") + println(s"Precision = ${metrics.precision}") + println(s"F1 measure = ${metrics.f1Measure}") + println(s"Accuracy = ${metrics.accuracy}") + + // Individual label stats + metrics.labels.foreach(label => + println(s"Class $label precision = ${metrics.precision(label)}")) + metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) + metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + + // Micro stats + println(s"Micro recall = ${metrics.microRecall}") + println(s"Micro precision = ${metrics.microPrecision}") + println(s"Micro F1 measure = ${metrics.microF1Measure}") + + // Hamming loss + println(s"Hamming loss = ${metrics.hammingLoss}") + + // Subset accuracy + println(s"Subset accuracy = ${metrics.subsetAccuracy}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala new file mode 100644 index 000000000000..090444924598 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkContext, SparkConf} + +object MulticlassMetricsExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MulticlassMetricsExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load training data in LIBSVM format + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + + // Split data into training (60%) and test (40%) + val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) + training.cache() + + // Run training algorithm to build the model + val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + + // Compute raw scores on the test set + val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) + } + + // Instantiate metrics object + val metrics = new MulticlassMetrics(predictionAndLabels) + + // Confusion matrix + println("Confusion matrix:") + println(metrics.confusionMatrix) + + // Overall Statistics + val precision = metrics.precision + val recall = metrics.recall // same as true positive rate + val f1Score = metrics.fMeasure + println("Summary Statistics") + println(s"Precision = $precision") + println(s"Recall = $recall") + println(s"F1 Score = $f1Score") + + // Precision by label + val labels = metrics.labels + labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) + } + + // Recall by label + labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) + } + + // False positive rate by label + labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) + } + + // F-measure by label + labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) + } + + // Weighted stats + println(s"Weighted precision: ${metrics.weightedPrecision}") + println(s"Weighted recall: ${metrics.weightedRecall}") + println(s"Weighted F1 score: ${metrics.weightedFMeasure}") + println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala new file mode 100644 index 000000000000..cffa03d5cc9f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} +// $example off$ +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} + +object RankingMetricsExample { + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("RankingMetricsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // $example on$ + // Read in the ratings data + val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) + }.cache() + + // Map ratings to 1 or 0, 1 indicating a movie that should be recommended + val binarizedRatings = ratings.map(r => Rating(r.user, r.product, + if (r.rating > 0) 1.0 else 0.0)).cache() + + // Summarize ratings + val numRatings = ratings.count() + val numUsers = ratings.map(_.user).distinct().count() + val numMovies = ratings.map(_.product).distinct().count() + println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + + // Build the model + val numIterations = 10 + val rank = 10 + val lambda = 0.01 + val model = ALS.train(ratings, rank, numIterations, lambda) + + // Define a function to scale ratings from 0 to 1 + def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) + } + + // Get sorted top ten predictions for each user and then scale from [0, 1] + val userRecommended = model.recommendProductsForUsers(10).map { case (user, recs) => + (user, recs.map(scaledRating)) + } + + // Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document + // Compare with top ten most relevant documents + val userMovies = binarizedRatings.groupBy(_.user) + val relevantDocuments = userMovies.join(userRecommended).map { case (user, (actual, + predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) + } + + // Instantiate metrics object + val metrics = new RankingMetrics(relevantDocuments) + + // Precision at K + Array(1, 3, 5).foreach { k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") + } + + // Mean average precision + println(s"Mean average precision = ${metrics.meanAveragePrecision}") + + // Normalized discounted cumulative gain + Array(1, 3, 5).foreach { k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") + } + + // Get predictions for each data point + val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, + r.product), r.rating)) + val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) + val predictionsAndLabels = allPredictions.join(allRatings).map { case ((user, product), + (predicted, actual)) => + (predicted, actual) + } + + // Get the RMSE using regression metrics + val regressionMetrics = new RegressionMetrics(predictionsAndLabels) + println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + + // R-squared + println(s"R-squared = ${regressionMetrics.r2}") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala 
b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala
new file mode 100644
index 000000000000..47d44532521c
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// scalastyle:off println
+
+package org.apache.spark.examples.mllib
+
+// $example on$
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.mllib.evaluation.RegressionMetrics
+import org.apache.spark.mllib.util.MLUtils
+// $example off$
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
+
+object RegressionMetricsExample {
+  def main(args: Array[String]) : Unit = {
+    val conf = new SparkConf().setAppName("RegressionMetricsExample")
+    val sc = new SparkContext(conf)
+    val sqlContext = new SQLContext(sc)
+    // $example on$
+    // Load the data
+    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache()
+
+    // Build the model
+    val numIterations = 100
+    val model = LinearRegressionWithSGD.train(data, numIterations)
+
+    // Get predictions
+    val valuesAndPreds = data.map{ point =>
+      val prediction = model.predict(point.features)
+      (prediction, point.label)
+    }
+
+    // Instantiate metrics object
+    val metrics = new RegressionMetrics(valuesAndPreds)
+
+    // Squared error
+    println(s"MSE = ${metrics.meanSquaredError}")
+    println(s"RMSE = ${metrics.rootMeanSquaredError}")
+
+    // R-squared
+    println(s"R-squared = ${metrics.r2}")
+
+    // Mean absolute error
+    println(s"MAE = ${metrics.meanAbsoluteError}")
+
+    // Explained variance
+    println(s"Explained variance = ${metrics.explainedVariance}")
+    // $example off$
+  }
+}
+// scalastyle:on println
+
From 58b4e4f88a330135c4cec04a30d24ef91bc61d91 Mon Sep 17 00:00:00 2001
From: Nong Li
Date: Fri, 20 Nov 2015 15:30:53 -0800
Subject: [PATCH 833/896] [SPARK-11787][SPARK-11883][SQL][FOLLOW-UP] Cleanup for this patch.

This mainly moves SqlNewHadoopRDD to the sql package. There is some state that is shared between core and SQL, and I've left that in core. This allows some other associated minor cleanup.

Author: Nong Li

Closes #9845 from nongli/spark-11787.
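One user-visible piece of this cleanup: the switch for the custom UnsafeRow Parquet reader moves from a hard-coded `SparkConf` lookup (`spark.parquet.enableUnsafeRowRecordReader`) to a documented SQLConf entry, `spark.sql.parquet.enableUnsafeRowRecordReader`, which defaults to true. A minimal sketch of toggling it from user code follows; the `sqlContext`, the Parquet path, and the column name `a` are placeholders for illustration and are not part of this patch.

```
// Sketch only: the config key and its default come from the SQLConf change below;
// the Parquet path and the filter column are hypothetical.
sqlContext.setConf("spark.sql.parquet.enableUnsafeRowRecordReader", "false")
val filtered = sqlContext.read.parquet("/path/to/parquet").filter("a = 2")
println(filtered.count())
// Restore the default so the custom UnsafeRow record reader is used again.
sqlContext.setConf("spark.sql.parquet.enableUnsafeRowRecordReader", "true")
```

The ParquetFilterSuite change below does the same thing through `withSQLConf` so the filter-pushdown test runs without the custom reader.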
--- .../org/apache/spark/rdd/HadoopRDD.scala | 6 +- .../spark/rdd/SqlNewHadoopRDDState.scala | 41 +++++++++++++ .../sql/catalyst/expressions/UnsafeRow.java | 59 ++++++++++++++---- .../catalyst/expressions/InputFileName.scala | 6 +- .../parquet/UnsafeRowParquetRecordReader.java | 14 +++++ .../scala/org/apache/spark/sql/SQLConf.scala | 5 ++ .../datasources}/SqlNewHadoopRDD.scala | 60 +++++++------------ .../datasources/parquet/ParquetRelation.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 43 ++++++------- .../datasources/parquet/ParquetIOSuite.scala | 19 ++++++ 10 files changed, 175 insertions(+), 80 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala rename {core/src/main/scala/org/apache/spark/rdd => sql/core/src/main/scala/org/apache/spark/sql/execution/datasources}/SqlNewHadoopRDD.scala (86%) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 7db583468792..f37c95bedc0a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -215,8 +215,8 @@ class HadoopRDD[K, V]( // Sets the thread local variable for the file's name split.inputSplit.value match { - case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDD.unsetInputFileName() + case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDDState.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -256,7 +256,7 @@ class HadoopRDD[K, V]( override def close() { if (reader != null) { - SqlNewHadoopRDD.unsetInputFileName() + SqlNewHadoopRDDState.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala new file mode 100644 index 000000000000..3f15fff79366 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.unsafe.types.UTF8String + +/** + * State for SqlNewHadoopRDD objects. This is split this way because of the package splits. 
+ * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD + */ +private[spark] object SqlNewHadoopRDDState { + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 33769363a0ed..b6979d0c8297 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,7 +17,11 @@ package org.apache.spark.sql.catalyst.expressions; -import java.io.*; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.io.OutputStream; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; @@ -26,12 +30,26 @@ import java.util.HashSet; import java.util.Set; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; - -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.CalendarIntervalType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.NullType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.UserDefinedType; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; @@ -39,9 +57,23 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.sql.types.DataTypes.*; +import static org.apache.spark.sql.types.DataTypes.BooleanType; +import static org.apache.spark.sql.types.DataTypes.ByteType; +import static org.apache.spark.sql.types.DataTypes.DateType; +import static org.apache.spark.sql.types.DataTypes.DoubleType; +import static org.apache.spark.sql.types.DataTypes.FloatType; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.LongType; +import static org.apache.spark.sql.types.DataTypes.NullType; +import static org.apache.spark.sql.types.DataTypes.ShortType; +import 
static org.apache.spark.sql.types.DataTypes.TimestampType; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * @@ -116,11 +148,6 @@ public static boolean isMutable(DataType dt) { /** The size of this row's backing data, in bytes) */ private int sizeInBytes; - private void setNotNullAt(int i) { - assertIndexIsValid(i); - BitSetMethods.unset(baseObject, baseOffset, i); - } - /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; @@ -187,6 +214,12 @@ public void pointTo(byte[] buf, int sizeInBytes) { pointTo(buf, numFields, sizeInBytes); } + + public void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); + } + @Override public void setNullAt(int i) { assertIndexIsValid(i); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index d809877817a5..bf215783fc27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.rdd.SqlNewHadoopRDD +import org.apache.spark.rdd.SqlNewHadoopRDDState import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{DataType, StringType} @@ -37,13 +37,13 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override protected def initInternal(): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { - SqlNewHadoopRDD.getInputFileName() + SqlNewHadoopRDDState.getInputFileName() } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { ev.isNull = "false" s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();" + "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 8a92e489ccb7..dade488ca281 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -108,6 +108,19 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas */ private static final int DEFAULT_VAR_LEN_SIZE = 32; + /** + * Tries to initialize the reader for this split. Returns true if this reader supports reading + * this split and false otherwise. + */ + public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) { + try { + initialize(inputSplit, taskAttemptContext); + return true; + } catch (Exception e) { + return false; + } + } + /** * Implementation of RecordReader API. 
*/ @@ -326,6 +339,7 @@ private void decodeBinaryBatch(int col, int num) throws IOException { } else { rowWriters[n].write(col, bytes.array(), bytes.position(), len); } + rows[n].setNotNullAt(col); } else { rows[n].setNullAt(col); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f40e603cd193..5ef3a48c56a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -323,6 +323,11 @@ private[spark] object SQLConf { "option must be set in Hadoop Configuration. 2. This option overrides " + "\"spark.sql.sources.outputCommitterClass\".") + val PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED = booleanConf( + key = "spark.sql.parquet.enableUnsafeRowRecordReader", + defaultValue = Some(true), + doc = "Enables using the custom ParquetUnsafeRowRecordReader.") + val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", defaultValue = Some(false), doc = "When true, enable filter pushdown for ORC files.") diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala similarity index 86% rename from core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 4d176332b69c..56cb63d9eff2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -28,13 +30,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader import org.apache.spark.storage.StorageLevel -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.{Utils, SerializableConfiguration, ShutdownHookManager} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} import org.apache.spark.{Partition => SparkPartition, _} -import scala.reflect.ClassTag - private[spark] class SqlNewHadoopPartition( rddId: Int, @@ -61,13 +62,13 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. */ private[spark] class SqlNewHadoopRDD[V: ClassTag]( - sc : SparkContext, + sqlContext: SQLContext, broadcastedConf: Broadcast[SerializableConfiguration], @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) - extends RDD[V](sc, Nil) + extends RDD[V](sqlContext.sparkContext, Nil) with SparkHadoopMapReduceUtil with Logging { @@ -99,7 +100,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // If true, enable using the custom RecordReader for parquet. This only works for // a subset of the types (no complex types). 
protected val enableUnsafeRowParquetReader: Boolean = - sc.conf.getBoolean("spark.parquet.enableUnsafeRowRecordReader", true) + sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) @@ -120,8 +121,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } override def compute( - theSplit: SparkPartition, - context: TaskContext): Iterator[V] = { + theSplit: SparkPartition, + context: TaskContext): Iterator[V] = { val iter = new Iterator[V] { val split = theSplit.asInstanceOf[SqlNewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) @@ -132,8 +133,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // Sets the thread local variable for the file's name split.serializableHadoopSplit.value match { - case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDD.unsetInputFileName() + case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDDState.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -163,15 +164,13 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( * TODO: plumb this through a different way? */ if (enableUnsafeRowParquetReader && - format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { - // TODO: move this class to sql.execution and remove this. - reader = Utils.classForName( - "org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader") - .newInstance().asInstanceOf[RecordReader[Void, V]] - try { - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - } catch { - case e: Exception => reader = null + format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { + val parquetReader: UnsafeRowParquetRecordReader = new UnsafeRowParquetRecordReader() + if (!parquetReader.tryInitialize( + split.serializableHadoopSplit.value, hadoopAttemptContext)) { + parquetReader.close() + } else { + reader = parquetReader.asInstanceOf[RecordReader[Void, V]] } } @@ -217,7 +216,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( private def close() { if (reader != null) { - SqlNewHadoopRDD.unsetInputFileName() + SqlNewHadoopRDDState.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic @@ -235,7 +234,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { @@ -276,23 +275,6 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } super.persist(storageLevel) } -} - -private[spark] object SqlNewHadoopRDD { - - /** - * The thread variable for the name of the current file being read. This is used by - * the InputFileName function in Spark SQL. 
- */ - private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { - override protected def initialValue(): UTF8String = UTF8String.fromString("") - } - - def getInputFileName(): UTF8String = inputFileName.get() - - private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) - - private[spark] def unsetInputFileName(): Unit = inputFileName.remove() /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index cb0aab8cc0d0..fdd745f48e97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -319,7 +319,7 @@ private[sql] class ParquetRelation( Utils.withDummyCallSite(sqlContext.sparkContext) { new SqlNewHadoopRDD( - sc = sqlContext.sparkContext, + sqlContext = sqlContext, broadcastedConf = broadcastedConf, initDriverSideJobFuncOpt = Some(setInputPaths), initLocalJobFuncOpt = Some(initLocalJobFuncOpt), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index c8028a5ef552..cc5aae03d551 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -337,29 +337,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - // Renable when we can toggle custom ParquetRecordReader on/off. The custom reader does - // not do row by row filtering (and we probably don't want to push that). - ignore("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + // The unsafe row RecordReader does not support row by row filtering so run it with it disabled. + test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withTempPath { dir => - val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) - val df = sqlContext.read.parquet(path).filter("a = 2") - - // This is the source RDD without Spark-side filtering. - val childRDD = - df - .queryExecution - .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] - .child - .execute() - - // The result should be single row. - // When a filter is pushed to Parquet, Parquet can apply it to every row. - // So, we can check the number of rows returned from the Parquet - // to make sure our filter pushdown work. - assert(childRDD.count == 1) + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + val df = sqlContext.read.parquet(path).filter("a = 2") + + // This is the source RDD without Spark-side filtering. + val childRDD = + df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + + // The result should be single row. 
+ // When a filter is pushed to Parquet, Parquet can apply it to every row. + // So, we can check the number of rows returned from the Parquet + // to make sure our filter pushdown work. + assert(childRDD.count == 1) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 177ab42f7767..0c5d4887ed79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -579,6 +579,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("null and non-null strings") { + // Create a dataset where the first values are NULL and then some non-null values. The + // number of non-nulls needs to be bigger than the ParquetReader batch size. + val data = sqlContext.range(200).map { i => + if (i.getLong(0) < 150) Row(None) + else Row("a") + } + val df = sqlContext.createDataFrame(data, StructType(StructField("col", StringType) :: Nil)) + assert(df.agg("col" -> "count").collect().head.getLong(0) == 50) + + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/data" + df.write.parquet(path) + + val df2 = sqlContext.read.parquet(path) + assert(df2.agg("col" -> "count").collect().head.getLong(0) == 50) + } + } + test("read dictionary encoded decimals written as INT32") { checkAnswer( // Decimal column in this file is encoded using plain dictionary From 968acf3bd9a502fcad15df3e53e359695ae702cc Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:36:30 -0800 Subject: [PATCH 834/896] [SPARK-11889][SQL] Fix type inference for GroupedDataset.agg in REPL In this PR I delete a method that breaks type inference for aggregators (only in the REPL) The error when this method is present is: ``` :38: error: missing parameter type for expanded function ((x$2) => x$2._2) ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() ``` Author: Michael Armbrust Closes #9870 from marmbrus/dataset-repl-agg. --- .../org/apache/spark/repl/ReplSuite.scala | 24 +++++++++++++++++ .../org/apache/spark/sql/GroupedDataset.scala | 27 +++---------------- .../apache/spark/sql/JavaDatasetSuite.java | 8 +++--- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 081aa03002cc..cbcccb11f14a 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -339,6 +339,30 @@ class ReplSuite extends SparkFunSuite { } } + test("Datasets agg type-inference") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |/** An `Aggregator` that adds up any numeric type returned by the given function. 
*/ + |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { + | val numeric = implicitly[Numeric[N]] + | override def zero: N = numeric.zero + | override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2) + | override def finish(reduction: N): N = reduction + |} + | + |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn + |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS() + |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("collecting objects of class defined in repl") { val output = runInterpreter("local[2]", """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 6de3dd626576..263f04910476 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -146,31 +146,10 @@ class GroupedDataset[K, T] private[sql]( reduce(f.call _) } - /** - * Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]]. - * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again. - * - * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. - * - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * - * // Scala: - * import org.apache.spark.sql.functions._ - * df.groupBy("department").agg(max("age"), sum("expense")) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.groupBy("department").agg(max("age"), sum("expense")); - * }}} - * - * We can also use `Aggregator.toColumn` to pass in typed aggregate functions. - * - * @since 1.6.0 - */ + // This is here to prevent us from adding overloads that would be ambiguous. 
@scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = - groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*) + private def agg(exprs: Column*): DataFrame = + groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*) private def withEncoder(c: Column): Column = c match { case tc: TypedColumn[_, _] => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ce40dd856f67..f7249b8945c4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -404,11 +404,9 @@ public String call(Tuple2 value) throws Exception { grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())); Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); - Dataset> agged2 = grouped.agg( - new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()), - expr("sum(_2)"), - count("*")) - .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG())); + Dataset> agged2 = grouped.agg( + new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())) + .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); Assert.assertEquals( Arrays.asList( new Tuple4<>("a", 3, 3L, 2L), From 68ed046836975b492b594967256d3c7951b568a5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 15:38:04 -0800 Subject: [PATCH 835/896] [SPARK-11890][SQL] Fix compilation for Scala 2.11 Author: Michael Armbrust Closes #9871 from marmbrus/scala211-break. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 918050b531c0..4a4a62ed1a46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -670,14 +670,14 @@ trait ScalaReflection { * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return * `NullType` silently instead. */ - private def silentSchemaFor(tpe: `Type`): Schema = try { + protected def silentSchemaFor(tpe: `Type`): Schema = try { schemaFor(tpe) } catch { case _: UnsupportedOperationException => Schema(NullType, nullable = true) } /** Returns the full class name for a type. 
*/ - private def getClassNameFromType(tpe: `Type`): String = { + protected def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } From 47815878ad5e47e89bfbd57acb848be2ce67a4a5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 20 Nov 2015 16:02:03 -0800 Subject: [PATCH 836/896] [HOTFIX] Fix Java Dataset Tests --- .../test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f7249b8945c4..f32374b4c04d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -409,8 +409,8 @@ public String call(Tuple2 value) throws Exception { .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); Assert.assertEquals( Arrays.asList( - new Tuple4<>("a", 3, 3L, 2L), - new Tuple4<>("b", 3, 3L, 1L)), + new Tuple2<>("a", 3), + new Tuple2<>("b", 3)), agged2.collectAsList()); } From a2dce22e0a25922e2052318d32f32877b7c27ec2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 20 Nov 2015 16:51:47 -0800 Subject: [PATCH 837/896] Revert "[SPARK-11689][ML] Add user guide and example code for LDA under spark.ml" This reverts commit e359d5dcf5bd300213054ebeae9fe75c4f7eb9e7. --- docs/ml-clustering.md | 30 ------ docs/ml-guide.md | 3 +- docs/mllib-guide.md | 1 - .../spark/examples/ml/JavaLDAExample.java | 94 ------------------- .../apache/spark/examples/ml/LDAExample.scala | 77 --------------- 5 files changed, 1 insertion(+), 204 deletions(-) delete mode 100644 docs/ml-clustering.md delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java delete mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md deleted file mode 100644 index 1743ef43a6dd..000000000000 --- a/docs/ml-clustering.md +++ /dev/null @@ -1,30 +0,0 @@ ---- -layout: global -title: Clustering - ML -displayTitle: ML - Clustering ---- - -In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). - -## Latent Dirichlet allocation (LDA) - -`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, -and generates a `LDAModel` as the base models. Expert users may cast a `LDAModel` generated by -`EMLDAOptimizer` to a `DistributedLDAModel` if needed. - -
-
-Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.LDA) for more details.
-
-
-{% include_example scala/org/apache/spark/examples/ml/LDAExample.scala %}
-
-
-
-
-Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) for more details.
-
-{% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %}
-
-
-
    \ No newline at end of file diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 6f35b30c3d4d..be18a05361a1 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -40,7 +40,6 @@ Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., provide class probabilities, and linear models provide model summaries. * [Feature extraction, transformation, and selection](ml-features.html) -* [Clustering](ml-clustering.html) * [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) @@ -951,4 +950,4 @@ model.transform(test) {% endhighlight %} - \ No newline at end of file + diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 54e35fcbb15a..91e50ccfecec 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -69,7 +69,6 @@ We list major functionality from both below, with links to detailed guides. concepts. It also contains sections on using algorithms within the Pipelines API, for example: * [Feature extraction, transformation, and selection](ml-features.html) -* [Clustering](ml-clustering.html) * [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java deleted file mode 100644 index b3a7d2eb2978..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.ml.clustering.LDA; -import org.apache.spark.ml.clustering.LDAModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -/** - * An example demonstrating LDA - * Run with - *
- * bin/run-example ml.JavaLDAExample
- *
    - */ -public class JavaLDAExample { - - private static class ParseVector implements Function { - private static final Pattern separator = Pattern.compile(" "); - - @Override - public Row call(String line) { - String[] tok = separator.split(line); - double[] point = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - point[i] = Double.parseDouble(tok[i]); - } - Vector[] points = {Vectors.dense(point)}; - return new GenericRow(points); - } - } - - public static void main(String[] args) { - - String inputFile = "data/mllib/sample_lda_data.txt"; - - // Parses the arguments - SparkConf conf = new SparkConf().setAppName("JavaLDAExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); - - // Loads data - JavaRDD points = jsc.textFile(inputFile).map(new ParseVector()); - StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; - StructType schema = new StructType(fields); - DataFrame dataset = sqlContext.createDataFrame(points, schema); - - // Trains a LDA model - LDA lda = new LDA() - .setK(10) - .setMaxIter(10); - LDAModel model = lda.fit(dataset); - - System.out.println(model.logLikelihood(dataset)); - System.out.println(model.logPerplexity(dataset)); - - // Shows the result - DataFrame topics = model.describeTopics(3); - topics.show(false); - model.transform(dataset).show(false); - - jsc.stop(); - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala deleted file mode 100644 index 419ce3d87a6a..000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml - -// scalastyle:off println -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} -// $example on$ -import org.apache.spark.ml.clustering.LDA -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.types.{StructField, StructType} -// $example off$ - -/** - * An example demonstrating a LDA of ML pipeline. 
- * Run with - * {{{ - * bin/run-example ml.LDAExample - * }}} - */ -object LDAExample { - - final val FEATURES_COL = "features" - - def main(args: Array[String]): Unit = { - - val input = "data/mllib/sample_lda_data.txt" - // Creates a Spark context and a SQL context - val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // $example on$ - // Loads data - val rowRDD = sc.textFile(input).filter(_.nonEmpty) - .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) - val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) - val dataset = sqlContext.createDataFrame(rowRDD, schema) - - // Trains a LDA model - val lda = new LDA() - .setK(10) - .setMaxIter(10) - .setFeaturesCol(FEATURES_COL) - val model = lda.fit(dataset) - val transformed = model.transform(dataset) - - val ll = model.logLikelihood(dataset) - val lp = model.logPerplexity(dataset) - - // describeTopics - val topics = model.describeTopics(3) - - // Shows the result - topics.show(false) - transformed.show(false) - - // $example off$ - sc.stop() - } -} -// scalastyle:on println From 7d3f922c4ba76c4193f98234ae662065c39cdfb1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Nov 2015 23:31:19 -0800 Subject: [PATCH 838/896] [SPARK-11819][SQL][FOLLOW-UP] fix scala 2.11 build seems scala 2.11 doesn't support: define private methods in `trait xxx` and use it in `object xxx extend xxx`. Author: Wenchen Fan Closes #9879 from cloud-fan/follow. --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4a4a62ed1a46..476becec4dd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -670,14 +670,14 @@ trait ScalaReflection { * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return * `NullType` silently instead. */ - protected def silentSchemaFor(tpe: `Type`): Schema = try { + def silentSchemaFor(tpe: `Type`): Schema = try { schemaFor(tpe) } catch { case _: UnsupportedOperationException => Schema(NullType, nullable = true) } /** Returns the full class name for a type. */ - protected def getClassNameFromType(tpe: `Type`): String = { + def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } From 54328b6d862fe62ae01bdd87df4798ceb9d506d6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 00:10:13 -0800 Subject: [PATCH 839/896] [SPARK-11900][SQL] Add since version for all encoders Author: Reynold Xin Closes #9881 from rxin/SPARK-11900. --- .../scala/org/apache/spark/sql/Encoder.scala | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 86bb53645903..5cb8edf64e87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -45,13 +45,52 @@ trait Encoder[T] extends Serializable { */ object Encoders { + /** + * An encoder for nullable boolean type. 
+ * @since 1.6.0 + */ def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + + /** + * An encoder for nullable byte type. + * @since 1.6.0 + */ def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + + /** + * An encoder for nullable short type. + * @since 1.6.0 + */ def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + + /** + * An encoder for nullable int type. + * @since 1.6.0 + */ def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + + /** + * An encoder for nullable long type. + * @since 1.6.0 + */ def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + + /** + * An encoder for nullable float type. + * @since 1.6.0 + */ def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + + /** + * An encoder for nullable double type. + * @since 1.6.0 + */ def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + + /** + * An encoder for nullable string type. + * @since 1.6.0 + */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() /** @@ -59,6 +98,8 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * T must be publicly accessible. + * + * @since 1.6.0 */ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) @@ -67,6 +108,8 @@ object Encoders { * This encoder maps T into a single byte array (binary) field. * * T must be publicly accessible. + * + * @since 1.6.0 */ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) @@ -77,6 +120,8 @@ object Encoders { * Note that this is extremely inefficient and should only be used as the last resort. * * T must be publicly accessible. + * + * @since 1.6.0 */ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) @@ -87,6 +132,8 @@ object Encoders { * Note that this is extremely inefficient and should only be used as the last resort. * * T must be publicly accessible. + * + * @since 1.6.0 */ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) @@ -120,12 +167,20 @@ object Encoders { ) } + /** + * An encoder for 2-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2]( e1: Encoder[T1], e2: Encoder[T2]): Encoder[(T1, T2)] = { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) } + /** + * An encoder for 3-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3]( e1: Encoder[T1], e2: Encoder[T2], @@ -133,6 +188,10 @@ object Encoders { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) } + /** + * An encoder for 4-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3, T4]( e1: Encoder[T1], e2: Encoder[T2], @@ -141,6 +200,10 @@ object Encoders { ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) } + /** + * An encoder for 5-ary tuples. + * @since 1.6.0 + */ def tuple[T1, T2, T3, T4, T5]( e1: Encoder[T1], e2: Encoder[T2], From 596710268e29e8f624c3ba2fade08b66ec7084eb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 00:54:18 -0800 Subject: [PATCH 840/896] [SPARK-11901][SQL] API audit for Aggregator. Author: Reynold Xin Closes #9882 from rxin/SPARK-11901. 
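A minimal sketch of the audited contract (illustrative only, not part of the patch): the
abstract members are `zero`, `reduce`, `merge` and `finish`, with input type I, buffer type B
and output type O, and `toColumn` turns the aggregator into a `TypedColumn` usable on a
`Dataset` given implicit encoders for B and O. (The class-level doc example in the diff below
still writes `present`; the abstract method it corresponds to is `finish`.)

    import org.apache.spark.sql.expressions.Aggregator

    case class Data(i: Int)

    // Sums the `i` field: I = Data, B = Int (running total), O = Int (final result).
    val dataSum = new Aggregator[Data, Int, Int] {
      def zero: Int = 0                            // neutral element for the buffer
      def reduce(b: Int, a: Data): Int = b + a.i   // fold one input into the buffer
      def merge(b1: Int, b2: Int): Int = b1 + b2   // combine two partial buffers
      def finish(reduction: Int): Int = reduction  // produce the final output
    }

    // With implicit encoders in scope (e.g. via sqlContext.implicits._):
    //   ds.select(dataSum.toColumn)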
--- .../scala/org/apache/spark/sql/Dataset.scala | 1 - .../spark/sql/expressions/Aggregator.scala | 39 ++++++++++++------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index bdcdc5d47cba..07647508421a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -22,7 +22,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 72610e735f78..b0cd32b5f73e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} /** * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] @@ -32,55 +31,65 @@ import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} * case class Data(i: Int) * * val customSummer = new Aggregator[Data, Int, Int] { - * def zero = 0 - * def reduce(b: Int, a: Data) = b + a.i - * def present(r: Int) = r + * def zero: Int = 0 + * def reduce(b: Int, a: Data): Int = b + a.i + * def merge(b1: Int, b2: Int): Int = b1 + b2 + * def present(r: Int): Int = r * }.toColumn() * - * val ds: Dataset[Data] + * val ds: Dataset[Data] = ... * val aggregated = ds.select(customSummer) * }}} * * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird * - * @tparam A The input type for the aggregation. + * @tparam I The input type for the aggregation. * @tparam B The type of the intermediate value of the reduction. - * @tparam C The type of the final result. + * @tparam O The type of the final output result. + * + * @since 1.6.0 */ -abstract class Aggregator[-A, B, C] extends Serializable { +abstract class Aggregator[-I, B, O] extends Serializable { - /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + /** + * A zero value for this aggregation. Should satisfy the property that any b + zero = b. + * @since 1.6.0 + */ def zero: B /** * Combine two values to produce a new value. For performance, the function may modify `b` and * return it instead of constructing new object for b. + * @since 1.6.0 */ - def reduce(b: B, a: A): B + def reduce(b: B, a: I): B /** - * Merge two intermediate values + * Merge two intermediate values. + * @since 1.6.0 */ def merge(b1: B, b2: B): B /** * Transform the output of the reduction. 
+ * @since 1.6.0 */ - def finish(reduction: B): C + def finish(reduction: B): O /** * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]] * operations. + * @since 1.6.0 */ def toColumn( implicit bEncoder: Encoder[B], - cEncoder: Encoder[C]): TypedColumn[A, C] = { + cEncoder: Encoder[O]): TypedColumn[I, O] = { val expr = new AggregateExpression( TypedAggregateExpression(this), Complete, false) - new TypedColumn[A, C](expr, encoderFor[C]) + new TypedColumn[I, O](expr, encoderFor[O]) } } From ff442bbcffd4f93cfcc2f76d160011e725d2fb3f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 21 Nov 2015 15:00:37 -0800 Subject: [PATCH 841/896] [SPARK-11899][SQL] API audit for GroupedDataset. 1. Renamed map to mapGroup, flatMap to flatMapGroup. 2. Renamed asKey -> keyAs. 3. Added more documentation. 4. Changed type parameter T to V on GroupedDataset. 5. Added since versions for all functions. Author: Reynold Xin Closes #9880 from rxin/SPARK-11899. --- .../api/java/function/MapGroupFunction.java | 2 +- .../scala/org/apache/spark/sql/Encoder.scala | 4 + .../sql/catalyst/JavaTypeInference.scala | 3 +- .../scala/org/apache/spark/sql/Column.scala | 2 + .../org/apache/spark/sql/DataFrame.scala | 1 - .../org/apache/spark/sql/GroupedDataset.scala | 132 ++++++++++++++---- .../apache/spark/sql/JavaDatasetSuite.java | 8 +- .../spark/sql/DatasetPrimitiveSuite.scala | 4 +- .../org/apache/spark/sql/DatasetSuite.scala | 20 +-- 9 files changed, 131 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java index 2935f9986a56..4f3f222e064b 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java @@ -21,7 +21,7 @@ import java.util.Iterator; /** - * Base interface for a map function used in GroupedDataset's map function. + * Base interface for a map function used in GroupedDataset's mapGroup function. */ public interface MapGroupFunction extends Serializable { R call(K key, Iterator values) throws Exception; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 5cb8edf64e87..03aa25eda807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.types._ * * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking * and reuse internal buffers to improve performance. + * + * @since 1.6.0 */ trait Encoder[T] extends Serializable { @@ -42,6 +44,8 @@ trait Encoder[T] extends Serializable { /** * Methods for creating encoders. + * + * @since 1.6.0 */ object Encoders { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 88a457f87ce4..7d4cfbe6faec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ /** * Type-inference utilities for POJOs and Java collections. 
*/ -private [sql] object JavaTypeInference { +object JavaTypeInference { private val iterableType = TypeToken.of(classOf[JIterable[_]]) private val mapType = TypeToken.of(classOf[JMap[_, _]]) @@ -53,7 +53,6 @@ private [sql] object JavaTypeInference { * @return (SQL data type, nullable) */ private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { - // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 82e9cd7f50a3..30c554a85e69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -46,6 +46,8 @@ private[sql] object Column { * @tparam T The input type expected for this expression. Can be `Any` if the expression is type * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). * @tparam U The output type of this column. + * + * @since 1.6.0 */ class TypedColumn[-T, U]( expr: Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 7abcecaa2880..5586fc994b98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -110,7 +110,6 @@ private[sql] object DataFrame { * @groupname action Actions * @since 1.3.0 */ -// TODO: Improve documentation. @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 263f04910476..7f43ce16901b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Ou import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.expressions.Aggregator /** * :: Experimental :: @@ -36,11 +37,13 @@ import org.apache.spark.sql.execution.QueryExecution * making this change to the class hierarchy would break some function signatures. As such, this * class should be considered a preview of the final API. Changes will be made to the interface * after Spark 1.6. + * + * @since 1.6.0 */ @Experimental -class GroupedDataset[K, T] private[sql]( +class GroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], - tEncoder: Encoder[T], + tEncoder: Encoder[V], val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { @@ -67,8 +70,10 @@ class GroupedDataset[K, T] private[sql]( /** * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. 
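 * An illustrative sketch (mirroring the updated DatasetSuite tests in this patch), assuming a
 * Dataset[(String, Int)] named `ds`:
 * {{{
 *   val byKey = ds.groupBy($"_1").keyAs[String]   // GroupedDataset[String, (String, Int)]
 * }}}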
+ * + * @since 1.6.0 */ - def asKey[L : Encoder]: GroupedDataset[L, T] = + def keyAs[L : Encoder]: GroupedDataset[L, V] = new GroupedDataset( encoderFor[L], unresolvedTEncoder, @@ -78,6 +83,8 @@ class GroupedDataset[K, T] private[sql]( /** * Returns a [[Dataset]] that contains each unique key. + * + * @since 1.6.0 */ def keys: Dataset[K] = { new Dataset[K]( @@ -92,12 +99,18 @@ class GroupedDataset[K, T] private[sql]( * function can return an iterator containing elements of an arbitrary type which will be returned * as a new [[Dataset]]. * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. + * + * @since 1.6.0 */ - def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = { + def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups( @@ -108,8 +121,25 @@ class GroupedDataset[K, T] private[sql]( logicalPlan)) } - def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { - flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder) } /** @@ -117,32 +147,62 @@ class GroupedDataset[K, T] private[sql]( * be passed the group key and an iterator that contains all of the elements in the group. The * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. 
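 * A short sketch (again mirroring the updated DatasetSuite tests), assuming a
 * Dataset[(String, Int)] named `ds`:
 * {{{
 *   ds.groupBy(_._1).mapGroup { (key, values) => (key, values.map(_._2).sum) }
 * }}}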
+ * + * @since 1.6.0 */ - def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = { - val func = (key: K, it: Iterator[T]) => Iterator(f(key, it)) - flatMap(func) + def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) + flatMapGroup(func) } - def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { - map((key, data) => f.call(key, data.asJava))(encoder) + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an [[Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @since 1.6.0 + */ + def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroup((key, data) => f.call(key, data.asJava))(encoder) } /** * Reduces the elements of each group of data using the specified binary function. * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 */ - def reduce(f: (T, T) => T): Dataset[(K, T)] = { - val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f)) + def reduce(f: (V, V) => V): Dataset[(K, V)] = { + val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder) - flatMap(func) + flatMapGroup(func) } - def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = { + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. + * + * @since 1.6.0 + */ + def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = { reduce(f.call _) } @@ -185,41 +245,51 @@ class GroupedDataset[K, T] private[sql]( /** * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key * and the result of computing this aggregation over all elements in the group. + * + * @since 1.6.0 */ - def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] = + def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 */ - def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] = + def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. 
+ * + * @since 1.6.0 */ def agg[U1, U2, U3]( - col1: TypedColumn[T, U1], - col2: TypedColumn[T, U2], - col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] = + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. + * + * @since 1.6.0 */ def agg[U1, U2, U3, U4]( - col1: TypedColumn[T, U1], - col2: TypedColumn[T, U2], - col3: TypedColumn[T, U3], - col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] = + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] /** * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. + * + * @since 1.6.0 */ def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) @@ -228,10 +298,12 @@ class GroupedDataset[K, T] private[sql]( * be passed the grouping key and 2 iterators containing all elements in the group from * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an * arbitrary type which will be returned as a new [[Dataset]]. + * + * @since 1.6.0 */ def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( - f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { + f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit def uEnc: Encoder[U] = other.unresolvedTEncoder new Dataset[R]( sqlContext, @@ -243,9 +315,17 @@ class GroupedDataset[K, T] private[sql]( other.logicalPlan)) } + /** + * Applies the given function to each cogrouped data. For each unique group, the function will + * be passed the grouping key and 2 iterators containing all elements in the group from + * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an + * arbitrary type which will be returned as a new [[Dataset]]. 
+ * + * @since 1.6.0 + */ def cogroup[U, R]( other: GroupedDataset[K, U], - f: CoGroupFunction[K, T, U, R], + f: CoGroupFunction[K, V, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f32374b4c04d..cf335efdd23b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -170,7 +170,7 @@ public Integer call(String v) throws Exception { } }, Encoders.INT()); - Dataset mapped = grouped.map(new MapGroupFunction() { + Dataset mapped = grouped.mapGroup(new MapGroupFunction() { @Override public String call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -183,7 +183,7 @@ public String call(Integer key, Iterator values) throws Exception { Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); - Dataset flatMapped = grouped.flatMap( + Dataset flatMapped = grouped.flatMapGroup( new FlatMapGroupFunction() { @Override public Iterable call(Integer key, Iterator values) throws Exception { @@ -247,9 +247,9 @@ public void testGroupByColumn() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = context.createDataset(data, Encoders.STRING()); GroupedDataset grouped = - ds.groupBy(length(col("value"))).asKey(Encoders.INT()); + ds.groupBy(length(col("value"))).keyAs(Encoders.INT()); - Dataset mapped = grouped.map( + Dataset mapped = grouped.mapGroup( new MapGroupFunction() { @Override public String call(Integer key, Iterator data) throws Exception { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 63b00975e4eb..d387710357be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.map { case (g, iter) => + val agged = grouped.mapGroup { case (g, iter) => val name = if (g == 0) "even" else "odd" (name, iter.size) } @@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq("a", "b", "c", "xyz", "hello").toDS() val grouped = ds.groupBy(_.length) - val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) } + val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) } checkAnswer( agged, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 89d964aa3e46..9da02550b39c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -224,7 +224,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.map { case (g, iter) => (g._1, 
iter.map(_._2).sum) } + val agged = grouped.mapGroup { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, @@ -234,7 +234,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } + val agged = grouped.flatMapGroup { case (g, iter) => + Iterator(g._1, iter.map(_._2).sum.toString) + } checkAnswer( agged, @@ -253,7 +255,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } + val agged = grouped.mapGroup { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, @@ -262,8 +264,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1").asKey[String] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -272,8 +274,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -282,8 +284,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData] - val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } + val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] + val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, From 426004a9c9a864f90494d08601e6974709091a56 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 22 Nov 2015 10:36:47 -0800 Subject: [PATCH 842/896] [SPARK-11908][SQL] Add NullType support to RowEncoder JIRA: https://issues.apache.org/jira/browse/SPARK-11908 We should add NullType support to RowEncoder. Author: Liang-Chi Hsieh Closes #9891 from viirya/rowencoder-nulltype. 
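A rough sketch of the round trip the updated RowEncoderSuite exercises (method names follow
the 1.6-era ExpressionEncoder; illustrative, not part of the patch):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    // A schema with a NullType field; before this change RowEncoder had no case for NullType.
    val schema = new StructType()
      .add("i", IntegerType)
      .add("n", NullType)

    val encoder = RowEncoder(schema)
    val roundTripped = encoder.fromRow(encoder.toRow(Row(1, null)))   // expected: Row(1, null)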
--- .../org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 5 +++-- .../org/apache/spark/sql/catalyst/expressions/objects.scala | 3 +++ .../apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 4cda4824acdc..fa553e7c5324 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -48,7 +48,7 @@ object RowEncoder { private def extractorsFor( inputObject: Expression, inputType: DataType): Expression = inputType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => inputObject case udt: UserDefinedType[_] => @@ -143,6 +143,7 @@ object RowEncoder { case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) case udt: UserDefinedType[_] => ObjectType(udt.userClass) + case _: NullType => ObjectType(classOf[java.lang.Object]) } private def constructorFor(schema: StructType): Expression = { @@ -158,7 +159,7 @@ object RowEncoder { } private def constructorFor(input: Expression): Expression = input.dataType match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => input case udt: UserDefinedType[_] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index ef7399e0196a..82317d338516 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -369,6 +369,9 @@ case class MapObjects( private lazy val completeFunction = function(loopAttribute) private def itemAccessorMethod(dataType: DataType): String => String = dataType match { + case NullType => + val nullTypeClassName = NullType.getClass.getName + ".MODULE$" + (i: String) => s".get($i, $nullTypeClassName)" case IntegerType => (i: String) => s".getInt($i)" case LongType => (i: String) => s".getLong($i)" case FloatType => (i: String) => s".getFloat($i)" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 46c6e0d98d34..0ea51ece4bc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -80,11 +80,13 @@ class RowEncoderSuite extends SparkFunSuite { private val structOfString = new StructType().add("str", StringType) private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) + private val arrayOfNull = ArrayType(NullType) private val mapOfString = MapType(StringType, StringType) private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) encodeDecodeTest( new StructType() + .add("null", NullType) .add("boolean", BooleanType) 
.add("byte", ByteType) .add("short", ShortType) @@ -101,6 +103,7 @@ class RowEncoderSuite extends SparkFunSuite { encodeDecodeTest( new StructType() + .add("arrayOfNull", arrayOfNull) .add("arrayOfString", arrayOfString) .add("arrayOfArrayOfString", ArrayType(arrayOfString)) .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) From fe89c1817d668e46adf70d0896c42c22a547c76a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 22 Nov 2015 21:45:46 -0800 Subject: [PATCH 843/896] [SPARK-11895][ML] rename and refactor DatasetExample under mllib/examples We used the name `Dataset` to refer to `SchemaRDD` in 1.2 in ML pipelines and created this example file. Since `Dataset` has a new meaning in Spark 1.6, we should rename it to avoid confusion. This PR also removes support for dense format to simplify the example code. cc: yinxusen Author: Xiangrui Meng Closes #9873 from mengxr/SPARK-11895. --- .../DataFrameExample.scala} | 71 +++++++------------ 1 file changed, 26 insertions(+), 45 deletions(-) rename examples/src/main/scala/org/apache/spark/examples/{mllib/DatasetExample.scala => ml/DataFrameExample.scala} (51%) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala similarity index 51% rename from examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala rename to examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index dc13f82488af..424f00158c2f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -16,7 +16,7 @@ */ // scalastyle:off println -package org.apache.spark.examples.mllib +package org.apache.spark.examples.ml import java.io.File @@ -24,25 +24,22 @@ import com.google.common.io.Files import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} /** - * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * ./bin/run-example ml.DataFrameExample [options] * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
*/ -object DatasetExample { +object DataFrameExample { - case class Params( - input: String = "data/mllib/sample_libsvm_data.txt", - dataFormat: String = "libsvm") extends AbstractParams[Params] + case class Params(input: String = "data/mllib/sample_libsvm_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -52,9 +49,6 @@ object DatasetExample { opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) - opt[String]("dataFormat") - .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") - .action((x, c) => c.copy(input = x)) checkConfig { params => success } @@ -69,55 +63,42 @@ object DatasetExample { def run(params: Params) { - val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val conf = new SparkConf().setAppName(s"DataFrameExample with $params") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // for implicit conversions // Load input data - val origData: RDD[LabeledPoint] = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) - } - println(s"Loaded ${origData.count()} instances from file: ${params.input}") - - // Convert input data to DataFrame explicitly. - val df: DataFrame = origData.toDF() - println(s"Inferred schema:\n${df.schema.prettyJson}") - println(s"Converted to DataFrame with ${df.count()} records") - - // Select columns - val labelsDf: DataFrame = df.select("label") - val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } - val numLabels = labels.count() - val meanLabel = labels.fold(0.0)(_ + _) / numLabels - println(s"Selected label column with average value $meanLabel") - - val featuresDf: DataFrame = df.select("features") - val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } + println(s"Loading LIBSVM file with UDT from ${params.input}.") + val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache() + println("Schema from LIBSVM:") + df.printSchema() + println(s"Loaded training data as a DataFrame with ${df.count()} records.") + + // Show statistical summary of labels. + val labelSummary = df.describe("label") + labelSummary.show() + + // Convert features column to an RDD of vectors. + val features = df.select("features").map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + // Save the records in a parquet file. val tmpDir = Files.createTempDir() tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") df.write.parquet(outputDir) + // Load the records back. 
println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.read.parquet(outputDir) - - println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } - val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), - (sum1, sum2) => sum1.merge(sum2)) - println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + val newDF = sqlContext.read.parquet(outputDir) + println(s"Schema from Parquet:") + newDF.printSchema() sc.stop() } - } // scalastyle:on println From a6fda0bfc16a13b28b1cecc96f1ff91363089144 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 22 Nov 2015 21:48:48 -0800 Subject: [PATCH 844/896] [SPARK-6791][ML] Add read/write for CrossValidator and Evaluators I believe this works for general estimators within CrossValidator, including compound estimators. (See the complex unit test.) Added read/write for all 3 Evaluators as well. CC: mengxr yanboliang Author: Joseph K. Bradley Closes #9848 from jkbradley/cv-io. --- .../scala/org/apache/spark/ml/Pipeline.scala | 38 +-- .../BinaryClassificationEvaluator.scala | 11 +- .../MulticlassClassificationEvaluator.scala | 12 +- .../ml/evaluation/RegressionEvaluator.scala | 11 +- .../apache/spark/ml/recommendation/ALS.scala | 14 +- .../spark/ml/tuning/CrossValidator.scala | 229 +++++++++++++++++- .../org/apache/spark/ml/util/ReadWrite.scala | 48 ++-- .../org/apache/spark/ml/PipelineSuite.scala | 4 +- .../BinaryClassificationEvaluatorSuite.scala | 13 +- ...lticlassClassificationEvaluatorSuite.scala | 13 +- .../evaluation/RegressionEvaluatorSuite.scala | 12 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 202 ++++++++++++++- 12 files changed, 522 insertions(+), 85 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 6f15b37abcb3..4b2b3f8489fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -232,20 +231,9 @@ object Pipeline extends MLReadable[Pipeline] { stages: Array[PipelineStage], sc: SparkContext, path: String): Unit = { - // Copied and edited from DefaultParamsWriter.saveMetadata - // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication - val uid = instance.uid - val cls = instance.getClass.getName val stageUids = stages.map(_.uid) val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq)))) - val metadata = ("class" -> cls) ~ - ("timestamp" -> System.currentTimeMillis()) ~ - ("sparkVersion" -> sc.version) ~ - ("uid" -> uid) ~ - ("paramMap" -> jsonParams) - val metadataPath = new Path(path, "metadata").toString - val metadataJson = compact(render(metadata)) - sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams)) // Save stages val stagesDir = new Path(path, "stages").toString @@ -266,30 +254,10 @@ object Pipeline extends MLReadable[Pipeline] { implicit val format = DefaultFormats val stagesDir = new Path(path, "stages").toString - val stageUids: Array[String] = metadata.params match { - 
case JObject(pairs) => - if (pairs.length != 1) { - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException( - s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.") - } - pairs.head match { - case ("stageUids", jsonValue) => - jsonValue.extract[Seq[String]].toArray - case (paramName, jsonValue) => - // Should not happen unless file is corrupted or we have a bug. - throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" + - s" in metadata: ${metadata.metadataStr}") - } - case _ => - throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") - } + val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) => val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir) - val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc) - val cls = Utils.classForName(stageMetadata.className) - cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath) + DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc) } (metadata.uid, stages) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 1fe3abaca81c..bfb70963b151 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.2.0") @Experimental class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasRawPredictionCol with HasLabelCol { + extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.2.0") def this() = this(Identifiable.randomUID("binEval")) @@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("1.4.1") override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): BinaryClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index df5f04ca5a8d..c44db0ec595e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation 
import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.types.DoubleType @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object MulticlassClassificationEvaluator + extends DefaultParamsReadable[MulticlassClassificationEvaluator] { + + @Since("1.6.0") + override def load(path: String): MulticlassClassificationEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index ba012f444d3e..daaa174a086e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType} @Since("1.4.0") @Experimental final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol { + extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) @@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.5.0") override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } + +@Since("1.6.0") +object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] { + + @Since("1.6.0") + override def load(path: String): RegressionEvaluator = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4d35177ad9b0..b798aa1fab76 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64 import 
com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.hadoop.fs.{FileSystem, Path} -import org.json4s.{DefaultFormats, JValue} +import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} @@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] { private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - val extraMetadata = render("rank" -> instance.rank) + val extraMetadata = "rank" -> instance.rank DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) val userPath = new Path(path, "userFactors").toString instance.userFactors.write.format("parquet").save(userPath) @@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] { override def load(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) implicit val format = DefaultFormats - val rank: Int = metadata.extraMetadata match { - case Some(m: JValue) => - (m \ "rank").extract[Int] - case None => - throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" + - s" ${metadata.metadataStr}") - } - + val rank = (metadata.metadata \ "rank").extract[Int] val userPath = new Path(path, "userFactors").toString val userFactors = sqlContext.read.format("parquet").load(userPath) val itemPath = new Path(path, "itemFactors").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 77d9948ed86b..83a904837426 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -18,17 +18,24 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS +import org.apache.hadoop.fs.Path +import org.json4s.{JObject, DefaultFormats} +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.classification.OneVsRestParams +import org.apache.spark.ml.feature.RFormulaModel +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType + /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ @@ -53,7 +60,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { */ @Experimental class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with Logging { + with CrossValidatorParams with MLWritable with Logging { def this() = this(Identifiable.randomUID("cv")) @@ -131,6 +138,166 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } copied } + + // Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types. + // E.g., this may fail if a [[Param]] is an instance of an [[Estimator]]. + // However, this case should be unusual. 
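+  // Illustrative usage sketch (not part of this patch), assuming the estimator and
+  // evaluator set on this CrossValidator are themselves MLWritable:
+  //   cv.write.overwrite().save("/tmp/spark-cv")
+  //   val restored = CrossValidator.load("/tmp/spark-cv")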
+ @Since("1.6.0") + override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this) +} + +@Since("1.6.0") +object CrossValidator extends MLReadable[CrossValidator] { + + @Since("1.6.0") + override def read: MLReader[CrossValidator] = new CrossValidatorReader + + @Since("1.6.0") + override def load(path: String): CrossValidator = super.load(path) + + private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = + SharedReadWrite.saveImpl(path, instance, sc) + } + + private class CrossValidatorReader extends MLReader[CrossValidator] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidator].getName + + override def load(path: String): CrossValidator = { + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + new CrossValidator(metadata.uid) + .setEstimator(estimator) + .setEvaluator(evaluator) + .setEstimatorParamMaps(estimatorParamMaps) + .setNumFolds(numFolds) + } + } + + private object CrossValidatorReader { + /** + * Examine the given estimator (which may be a compound estimator) and extract a mapping + * from UIDs to corresponding [[Params]] instances. + */ + def getUidMap(instance: Params): Map[String, Params] = { + val uidList = getUidMapImpl(instance) + val uidMap = uidList.toMap + if (uidList.size != uidMap.size) { + throw new RuntimeException("CrossValidator.load found a compound estimator with stages" + + s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}") + } + uidMap + } + + def getUidMapImpl(instance: Params): List[(String, Params)] = { + val subStages: Array[Params] = instance match { + case p: Pipeline => p.getStages.asInstanceOf[Array[Params]] + case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]] + case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator) + case ovr: OneVsRestParams => + // TODO: SPARK-11892: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing type: ${ovr.getClass.getName}") + case rform: RFormulaModel => + // TODO: SPARK-11891: This case may require special handling. + throw new UnsupportedOperationException("CrossValidator write will fail because it" + + " cannot yet handle an estimator containing an RFormulaModel") + case _: Params => Array() + } + val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _) + List((instance.uid, instance)) ++ subStageMaps + } + } + + private[tuning] object SharedReadWrite { + + /** + * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable. + * This does not check [[CrossValidator.estimatorParamMaps]]. + */ + def validateParams(instance: ValidatorParams): Unit = { + def checkElement(elem: Params, name: String): Unit = elem match { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("CrossValidator write will fail " + + s" because it contains $name which does not implement Writable." + + s" Non-Writable $name: ${other.uid} of type ${other.getClass}") + } + checkElement(instance.getEvaluator, "evaluator") + checkElement(instance.getEstimator, "estimator") + // Check to make sure all Params apply to this estimator. Throw an error if any do not. 
+ // Extraneous Params would cause problems when loading the estimatorParamMaps. + val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance) + instance.getEstimatorParamMaps.foreach { case pMap: ParamMap => + pMap.toSeq.foreach { case ParamPair(p, v) => + require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" + + s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" + + s" Evaluator. An extraneous Param was found: $p") + } + } + } + + private[tuning] def saveImpl( + path: String, + instance: CrossValidatorParams, + sc: SparkContext, + extraMetadata: Option[JObject] = None): Unit = { + import org.json4s.JsonDSL._ + + val estimatorParamMapsJson = compact(render( + instance.getEstimatorParamMaps.map { case paramMap => + paramMap.toSeq.map { case ParamPair(p, v) => + Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) + } + }.toSeq + )) + val jsonParams = List( + "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)), + "estimatorParamMaps" -> parse(estimatorParamMapsJson) + ) + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) + + val evaluatorPath = new Path(path, "evaluator").toString + instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath) + val estimatorPath = new Path(path, "estimator").toString + instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath) + } + + private[tuning] def load[M <: Model[M]]( + path: String, + sc: SparkContext, + expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = { + + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + + implicit val format = DefaultFormats + val evaluatorPath = new Path(path, "evaluator").toString + val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc) + val estimatorPath = new Path(path, "estimator").toString + val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc) + + val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator) + + val numFolds = (metadata.params \ "numFolds").extract[Int] + val estimatorParamMaps: Array[ParamMap] = + (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map { + pMap => + val paramPairs = pMap.map { case pInfo: Map[String, String] => + val est = uidToParams(pInfo("parent")) + val param = est.getParam(pInfo("name")) + val value = param.jsonDecode(pInfo("value")) + param -> value + } + ParamMap(paramPairs: _*) + }.toArray + (metadata, estimator, evaluator, estimatorParamMaps, numFolds) + } + } } /** @@ -139,14 +306,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM * * @param bestModel The best model selected from k-fold cross validation. * @param avgMetrics Average cross-validation metrics for each paramMap in - * [[estimatorParamMaps]], in the corresponding order. + * [[CrossValidator.estimatorParamMaps]], in the corresponding order. 
*/ @Experimental class CrossValidatorModel private[ml] ( override val uid: String, val bestModel: Model[_], val avgMetrics: Array[Double]) - extends Model[CrossValidatorModel] with CrossValidatorParams { + extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { override def validateParams(): Unit = { bestModel.validateParams() @@ -168,4 +335,54 @@ class CrossValidatorModel private[ml] ( avgMetrics.clone()) copyValues(copied, extra).setParent(parent) } + + @Since("1.6.0") + override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) +} + +@Since("1.6.0") +object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + + import CrossValidator.SharedReadWrite + + @Since("1.6.0") + override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader + + @Since("1.6.0") + override def load(path: String): CrossValidatorModel = super.load(path) + + private[CrossValidatorModel] + class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + + SharedReadWrite.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + import org.json4s.JsonDSL._ + val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata)) + val bestModelPath = new Path(path, "bestModel").toString + instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + } + } + + private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[CrossValidatorModel].getName + + override def load(path: String): CrossValidatorModel = { + implicit val format = DefaultFormats + + val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) = + SharedReadWrite.load(path, sc, className) + val bestModelPath = new Path(path, "bestModel").toString + val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) + val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray + val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + cv.set(cv.estimator, estimator) + .set(cv.evaluator, evaluator) + .set(cv.estimatorParamMaps, estimatorParamMaps) + .set(cv.numFolds, numFolds) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ff9322dba122..8484b1f80106 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -202,25 +202,36 @@ private[ml] object DefaultParamsWriter { * - timestamp * - sparkVersion * - uid - * - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]]. + * - paramMap + * - (optionally, extra metadata) + * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. + * @param paramMap If given, this is saved in the "paramMap" field. + * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using + * [[org.apache.spark.ml.param.Param.jsonEncode()]]. 
*/ def saveMetadata( instance: Params, path: String, sc: SparkContext, - extraMetadata: Option[JValue] = None): Unit = { + extraMetadata: Option[JObject] = None, + paramMap: Option[JValue] = None): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] - val jsonParams = params.map { case ParamPair(p, v) => + val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) - }.toList - val metadata = ("class" -> cls) ~ + }.toList)) + val basicMetadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ - ("paramMap" -> jsonParams) ~ - ("extraMetadata" -> extraMetadata) + ("paramMap" -> jsonParams) + val metadata = extraMetadata match { + case Some(jObject) => + basicMetadata ~ jObject + case None => + basicMetadata + } val metadataPath = new Path(path, "metadata").toString val metadataJson = compact(render(metadata)) sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) @@ -251,8 +262,8 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. * @param params paramMap, as a [[JValue]] - * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]] - * @param metadataStr Full metadata file String (for debugging) + * @param metadata All metadata, including the other fields + * @param metadataJson Full metadata file String (for debugging) */ case class Metadata( className: String, @@ -260,8 +271,8 @@ private[ml] object DefaultParamsReader { timestamp: Long, sparkVersion: String, params: JValue, - extraMetadata: Option[JValue], - metadataStr: String) + metadata: JValue, + metadataJson: String) /** * Load metadata from file. @@ -279,13 +290,12 @@ private[ml] object DefaultParamsReader { val timestamp = (metadata \ "timestamp").extract[Long] val sparkVersion = (metadata \ "sparkVersion").extract[String] val params = metadata \ "paramMap" - val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]] if (expectedClassName.nonEmpty) { require(className == expectedClassName, s"Error loading metadata: Expected class name" + s" $expectedClassName but found class name $className") } - Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr) + Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) } /** @@ -303,7 +313,17 @@ private[ml] object DefaultParamsReader { } case _ => throw new IllegalArgumentException( - s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") } } + + /** + * Load a [[Params]] instance from the given path, and return it. + * This assumes the instance implements [[MLReadable]]. 
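+   * @param path  Directory previously written by the instance's [[MLWriter]]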
+ */ + def loadParamsInstance[T](path: String, sc: SparkContext): T = { + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 12aba6bc6dbe..8c8676745636 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.ml -import java.io.File - import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index def869fe6677..a535c1218ecf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class BinaryClassificationEvaluatorSuite extends SparkFunSuite { +class BinaryClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new BinaryClassificationEvaluator) } + + test("read/write") { + val evaluator = new BinaryClassificationEvaluator() + .setRawPredictionCol("myRawPrediction") + .setLabelCol("myLabel") + .setMetricName("areaUnderPR") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 6d8412b0b370..7ee65975d22f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { +class MulticlassClassificationEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new MulticlassClassificationEvaluator) } + + test("read/write") { + val evaluator = new MulticlassClassificationEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("recall") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index aa722da32393..60886bf77d2f 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegressionEvaluatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new RegressionEvaluator) @@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext evaluator.setMetricName("mae") assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) } + + test("read/write") { + val evaluator = new RegressionEvaluator() + .setPredictionCol("myPrediction") + .setLabelCol("myLabel") + .setMetricName("r2") + testDefaultReadWrite(evaluator) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index cbe09292a033..dd6366050c02 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,19 +18,22 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.{Pipeline, Estimator, Model} +import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.{ParamPair, ParamMap} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { +class CrossValidatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ @@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { } test("validateParams should check estimatorParamMaps") { - import CrossValidatorSuite._ + import CrossValidatorSuite.{MyEstimator, MyEvaluator} val est = new MyEstimator("est") val eval = new MyEvaluator @@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { cv.validateParams() } } + + test("read/write: CrossValidator with simple estimator") { + val lr = new LogisticRegression().setMaxIter(3) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") 
// not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + } + + test("read/write: CrossValidator with complex estimator") { + // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]] + val lrEvaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + + val lr = new LogisticRegression().setMaxIter(3) + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val lrcv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(lrEvaluator) + .setEstimatorParamMaps(lrParamMaps) + + val hashingTF = new HashingTF() + val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv)) + val paramMaps = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 20)) + .addGrid(lr.elasticNetParam, Array(0.0, 1.0)) + .build() + val evaluator = new BinaryClassificationEvaluator() + + val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(evaluator) + .setNumFolds(20) + .setEstimatorParamMaps(paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.getEstimator match { + case pipeline2: Pipeline => + assert(pipeline.uid === pipeline2.uid) + pipeline2.getStages match { + case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) => + assert(hashingTF.uid === hashingTF2.uid) + lrcv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + case other => + throw new AssertionError(s"Loaded internal CrossValidator expected to be" + + s" LogisticRegression but found type ${other.getClass.getName}") + } + assert(lrcv.uid === lrcv2.uid) + assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + assert(lrEvaluator.uid === lrcv2.getEvaluator.uid) + CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps) + case other => + throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" + + " but found: " + other.map(_.getClass.getName).mkString(", ")) + } + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" CrossValidator but found ${other.getClass.getName}") + } + } + + test("read/write: CrossValidator fails for extraneous Param") { 
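+    // lr2 is never set as this CrossValidator's estimator, so grids over lr2.regParam
+    // introduce an extraneous Param and write is expected to fail validation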
+ val lr = new LogisticRegression() + val lr2 = new LogisticRegression() + val evaluator = new BinaryClassificationEvaluator() + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .addGrid(lr2.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidator() + .setEstimator(lr) + .setEvaluator(evaluator) + .setEstimatorParamMaps(paramMaps) + withClue("CrossValidator.write failed to catch extraneous Param error") { + intercept[IllegalArgumentException] { + cv.write + } + } + } + + test("read/write: CrossValidatorModel") { + val lr = new LogisticRegression() + .setThreshold(0.6) + val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2) + .setThreshold(0.6) + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderPR") // not default metric + val paramMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.2)) + .build() + val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6)) + cv.set(cv.estimator, lr) + .set(cv.evaluator, evaluator) + .set(cv.numFolds, 20) + .set(cv.estimatorParamMaps, paramMaps) + + val cv2 = testDefaultReadWrite(cv, testParams = false) + + assert(cv.uid === cv2.uid) + assert(cv.getNumFolds === cv2.getNumFolds) + + assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) + val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] + assert(evaluator.uid === evaluator2.uid) + assert(evaluator.getMetricName === evaluator2.getMetricName) + + cv2.getEstimator match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getThreshold === lr2.getThreshold) + case other => + throw new AssertionError(s"Loaded CrossValidator expected estimator of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) + + cv2.bestModel match { + case lrModel2: LogisticRegressionModel => + assert(lrModel.uid === lrModel2.uid) + assert(lrModel.getThreshold === lrModel2.getThreshold) + assert(lrModel.coefficients === lrModel2.coefficients) + assert(lrModel.intercept === lrModel2.intercept) + case other => + throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" + + s" LogisticRegressionModel but found ${other.getClass.getName}") + } + assert(cv.avgMetrics === cv2.avgMetrics) + } } -object CrossValidatorSuite { +object CrossValidatorSuite extends SparkFunSuite { + + /** + * Assert sequences of estimatorParamMaps are identical. + * Params must be simple types comparable with `===`. + */ + def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = { + assert(pMaps.length === pMaps2.length) + pMaps.zip(pMaps2).foreach { case (pMap, pMap2) => + assert(pMap.size === pMap2.size) + pMap.toSeq.foreach { case ParamPair(p, v) => + assert(pMap2.contains(p)) + assert(pMap2(p) === v) + } + } + } abstract class MyModel extends Model[MyModel] From fc4b792d287095d70379a51f117c225d8d857078 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Sun, 22 Nov 2015 21:51:42 -0800 Subject: [PATCH 845/896] [SPARK-11835] Adds a sidebar menu to MLlib's documentation This PR adds a sidebar menu when browsing the user guide of MLlib. It uses a YAML file to describe the structure of the documentation. It should be trivial to adapt this to the other projects. 
![screen shot 2015-11-18 at 4 46 12 pm](https://cloud.githubusercontent.com/assets/7594753/11259591/a55173f4-8e17-11e5-9340-0aed79d66262.png) Author: Timothy Hunter Closes #9826 from thunterdb/spark-11835. --- docs/_data/menu-ml.yaml | 10 ++++ docs/_data/menu-mllib.yaml | 75 +++++++++++++++++++++++++ docs/_includes/nav-left-wrapper-ml.html | 8 +++ docs/_includes/nav-left.html | 17 ++++++ docs/_layouts/global.html | 24 +++++--- docs/css/main.css | 37 ++++++++++++ 6 files changed, 163 insertions(+), 8 deletions(-) create mode 100644 docs/_data/menu-ml.yaml create mode 100644 docs/_data/menu-mllib.yaml create mode 100644 docs/_includes/nav-left-wrapper-ml.html create mode 100644 docs/_includes/nav-left.html diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml new file mode 100644 index 000000000000..dff3d33bf4ed --- /dev/null +++ b/docs/_data/menu-ml.yaml @@ -0,0 +1,10 @@ +- text: Feature extraction, transformation, and selection + url: ml-features.html +- text: Decision trees for classification and regression + url: ml-decision-tree.html +- text: Ensembles + url: ml-ensembles.html +- text: Linear methods with elastic-net regularization + url: ml-linear-methods.html +- text: Multilayer perceptron classifier + url: ml-ann.html diff --git a/docs/_data/menu-mllib.yaml b/docs/_data/menu-mllib.yaml new file mode 100644 index 000000000000..12d22abd5282 --- /dev/null +++ b/docs/_data/menu-mllib.yaml @@ -0,0 +1,75 @@ +- text: Data types + url: mllib-data-types.html +- text: Basic statistics + url: mllib-statistics.html + subitems: + - text: Summary statistics + url: mllib-statistics.html#summary-statistics + - text: Correlations + url: mllib-statistics.html#correlations + - text: Stratified sampling + url: mllib-statistics.html#stratified-sampling + - text: Hypothesis testing + url: mllib-statistics.html#hypothesis-testing + - text: Random data generation + url: mllib-statistics.html#random-data-generation +- text: Classification and regression + url: mllib-classification-regression.html + subitems: + - text: Linear models (SVMs, logistic regression, linear regression) + url: mllib-linear-methods.html + - text: Naive Bayes + url: mllib-naive-bayes.html + - text: decision trees + url: mllib-decision-tree.html + - text: ensembles of trees (Random Forests and Gradient-Boosted Trees) + url: mllib-ensembles.html + - text: isotonic regression + url: mllib-isotonic-regression.html +- text: Collaborative filtering + url: mllib-collaborative-filtering.html + subitems: + - text: alternating least squares (ALS) + url: mllib-collaborative-filtering.html#collaborative-filtering +- text: Clustering + url: mllib-clustering.html + subitems: + - text: k-means + url: mllib-clustering.html#k-means + - text: Gaussian mixture + url: mllib-clustering.html#gaussian-mixture + - text: power iteration clustering (PIC) + url: mllib-clustering.html#power-iteration-clustering-pic + - text: latent Dirichlet allocation (LDA) + url: mllib-clustering.html#latent-dirichlet-allocation-lda + - text: streaming k-means + url: mllib-clustering.html#streaming-k-means +- text: Dimensionality reduction + url: mllib-dimensionality-reduction.html + subitems: + - text: singular value decomposition (SVD) + url: mllib-dimensionality-reduction.html#singular-value-decomposition-svd + - text: principal component analysis (PCA) + url: mllib-dimensionality-reduction.html#principal-component-analysis-pca +- text: Feature extraction and transformation + url: mllib-feature-extraction.html +- text: Frequent pattern mining + url: 
mllib-frequent-pattern-mining.html + subitems: + - text: FP-growth + url: mllib-frequent-pattern-mining.html#fp-growth + - text: association rules + url: mllib-frequent-pattern-mining.html#association-rules + - text: PrefixSpan + url: mllib-frequent-pattern-mining.html#prefix-span +- text: Evaluation metrics + url: mllib-evaluation-metrics.html +- text: PMML model export + url: mllib-pmml-model-export.html +- text: Optimization (developer) + url: mllib-optimization.html + subitems: + - text: stochastic gradient descent + url: mllib-optimization.html#stochastic-gradient-descent-sgd + - text: limited-memory BFGS (L-BFGS) + url: mllib-optimization.html#limited-memory-bfgs-l-bfgs diff --git a/docs/_includes/nav-left-wrapper-ml.html b/docs/_includes/nav-left-wrapper-ml.html new file mode 100644 index 000000000000..0103e890cc21 --- /dev/null +++ b/docs/_includes/nav-left-wrapper-ml.html @@ -0,0 +1,8 @@ +
+<div class="left-menu-wrapper">
+    <div class="left-menu">
+        <h3>spark.ml package</h3>
+        {% include nav-left.html nav=include.nav-ml %}
+        <h3>spark.mllib package</h3>
+        {% include nav-left.html nav=include.nav-mllib %}
+    </div>
+</div>
    \ No newline at end of file diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html new file mode 100644 index 000000000000..73176f413255 --- /dev/null +++ b/docs/_includes/nav-left.html @@ -0,0 +1,17 @@ +{% assign navurl = page.url | remove: 'index.html' %} + diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 467ff7a03fb7..1b09e2221e17 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -124,16 +124,24 @@ -
    - {% if page.displayTitle %} -

    {{ page.displayTitle }}

    - {% else %} -

    {{ page.title }}

    - {% endif %} +
    - {{ content }} + {% if page.url contains "/ml" %} + {% include nav-left-wrapper-ml.html nav-mllib=site.data.menu-mllib nav-ml=site.data.menu-ml %} + {% endif %} -
    + +
    + {% if page.displayTitle %} +

    {{ page.displayTitle }}

    + {% else %} +

    {{ page.title }}

    + {% endif %} + + {{ content }} + +
    +
    diff --git a/docs/css/main.css b/docs/css/main.css index d770173be101..356b324d6303 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -39,8 +39,18 @@ margin-left: 10px; } +body .container-wrapper { + position: absolute; + width: 100%; + display: flex; +} + body #content { + position: relative; + line-height: 1.6; /* Inspired by Github's wiki style */ + background-color: white; + padding-left: 15px; } .title { @@ -155,3 +165,30 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { * AnchorJS (anchor links when hovering over headers) */ a.anchorjs-link:hover { text-decoration: none; } + + +/** + * The left navigation bar. + */ +.left-menu-wrapper { + position: absolute; + height: 100%; + + width: 256px; + margin-top: -20px; + padding-top: 20px; + background-color: #F0F8FC; +} + +.left-menu { + position: fixed; + max-width: 350px; + + padding-right: 10px; + width: 256px; +} + +.left-menu h3 { + margin-left: 10px; + line-height: 30px; +} \ No newline at end of file From d9cf9c21fc6b1aa22e68d66760afd42c4e1c18b8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 22 Nov 2015 21:56:07 -0800 Subject: [PATCH 846/896] [SPARK-11912][ML] ml.feature.PCA minor refactor Like [SPARK-11852](https://issues.apache.org/jira/browse/SPARK-11852), ```k``` is params and we should save it under ```metadata/``` rather than both under ```data/``` and ```metadata/```. Refactor the constructor of ```ml.feature.PCAModel``` to take only ```pc``` but construct ```mllib.feature.PCAModel``` inside ```transform```. Author: Yanbo Liang Closes #9897 from yanboliang/spark-11912. --- .../org/apache/spark/ml/feature/PCA.scala | 23 +++++++------- .../apache/spark/ml/feature/PCASuite.scala | 31 ++++++++----------- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 32d7afee6e73..aa88cb03d23c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -73,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) val pcaModel = pca.fit(input) - copyValues(new PCAModel(uid, pcaModel).setParent(this)) + copyValues(new PCAModel(uid, pcaModel.pc).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -99,18 +99,17 @@ object PCA extends DefaultParamsReadable[PCA] { /** * :: Experimental :: * Model fitted by [[PCA]]. + * + * @param pc A principal components Matrix. Each column is one principal component. */ @Experimental class PCAModel private[ml] ( override val uid: String, - pcaModel: feature.PCAModel) + val pc: DenseMatrix) extends Model[PCAModel] with PCAParams with MLWritable { import PCAModel._ - /** a principal components Matrix. Each column is one principal component. 
*/ - val pc: DenseMatrix = pcaModel.pc - /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -124,6 +123,7 @@ class PCAModel private[ml] ( */ override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) + val pcaModel = new feature.PCAModel($(k), pc) val pcaOp = udf { pcaModel.transform _ } dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) } @@ -139,7 +139,7 @@ class PCAModel private[ml] ( } override def copy(extra: ParamMap): PCAModel = { - val copied = new PCAModel(uid, pcaModel) + val copied = new PCAModel(uid, pc) copyValues(copied, extra).setParent(parent) } @@ -152,11 +152,11 @@ object PCAModel extends MLReadable[PCAModel] { private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter { - private case class Data(k: Int, pc: DenseMatrix) + private case class Data(pc: DenseMatrix) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.getK, instance.pc) + val data = Data(instance.pc) val dataPath = new Path(path, "data").toString sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } @@ -169,11 +169,10 @@ object PCAModel extends MLReadable[PCAModel] { override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath) - .select("k", "pc") + val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath) + .select("pc") .head() - val oldModel = new feature.PCAModel(k, pc) - val model = new PCAModel(metadata.uid, oldModel) + val model = new PCAModel(metadata.uid, pc) DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 5a21cd20ceed..edab21e6c307 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -32,7 +32,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead test("params") { ParamsSuite.checkParams(new PCA) val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] - val model = new PCAModel("pca", new OldPCAModel(2, mat)) + val model = new PCAModel("pca", mat) ParamsSuite.checkParams(model) } @@ -66,23 +66,18 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } } - test("read/write") { + test("PCA read/write") { + val t = new PCA() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setK(3) + testDefaultReadWrite(t) + } - def checkModelData(model1: PCAModel, model2: PCAModel): Unit = { - assert(model1.pc === model2.pc) - } - val allParams: Map[String, Any] = Map( - "k" -> 3, - "inputCol" -> "features", - "outputCol" -> "pca_features" - ) - val data = Seq( - (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))), - (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), - (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - ) - val df = sqlContext.createDataFrame(data).toDF("id", "features") - val pca = new PCA().setK(3) - testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData) + test("PCAModel read/write") { + val instance = new PCAModel("myPCAModel", + Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]) + val newInstance = testDefaultReadWrite(instance) 
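+    // the reloaded model must expose the same principal components matrix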
+ assert(newInstance.pc === instance.pc) } } From 4be360d4ee6cdb4d06306feca38ddef5212608cf Mon Sep 17 00:00:00 2001 From: BenFradet Date: Sun, 22 Nov 2015 22:05:01 -0800 Subject: [PATCH 847/896] [SPARK-11902][ML] Unhandled case in VectorAssembler#transform There is an unhandled case in the transform method of VectorAssembler if one of the input columns doesn't have one of the supported type DoubleType, NumericType, BooleanType or VectorUDT. So, if you try to transform a column of StringType you get a cryptic "scala.MatchError: StringType". This PR aims to fix this, throwing a SparkException when dealing with an unknown column type. Author: BenFradet Closes #9885 from BenFradet/SPARK-11902. --- .../org/apache/spark/ml/feature/VectorAssembler.scala | 2 ++ .../spark/ml/feature/VectorAssemblerSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 0feec0549852..801096fed27b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -84,6 +84,8 @@ class VectorAssembler(override val uid: String) val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) Array.fill(numAttrs)(NumericAttribute.defaultAttr) } + case otherType => + throw new SparkException(s"VectorAssembler does not support the $otherType type") } } val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index fb21ab6b9bf2..9c1c00f41ab1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -69,6 +69,17 @@ class VectorAssemblerSuite } } + test("transform should throw an exception in case of unsupported type") { + val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val assembler = new VectorAssembler() + .setInputCols(Array("a", "b", "c")) + .setOutputCol("features") + val thrown = intercept[SparkException] { + assembler.transform(df) + } + assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + } + test("ML attributes") { val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) From 94ce65dfcbba1fe3a1fc9d8002c37d9cd1a11336 Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Mon, 23 Nov 2015 08:53:40 -0800 Subject: [PATCH 848/896] [SPARK-11628][SQL] support column datatype of char(x) to recognize HiveChar Can someone review my code to make sure I'm not missing anything? Thanks! Author: Xiu Guo Author: Xiu Guo Closes #9612 from xguo27/SPARK-11628. 
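A minimal sketch of the effect of this change (illustrative only, not part of the patch; the ```CharTypeSketch``` object name is made up): a ```char(n)``` column type in a schema string now parses to catalyst ```StringType```, exactly like ```varchar(n)```. ```DataTypeParser``` is ```private[sql]```, so the snippet assumes it is compiled inside an ```org.apache.spark.sql``` subpackage and that the companion object's ```parse``` helper (the one exercised by ```DataTypeParserSuite```) is the entry point.

```scala
package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.types.{ArrayType, StringType}

// Hypothetical sanity check of the new CHAR(n) handling (not part of this patch).
object CharTypeSketch {
  def main(args: Array[String]): Unit = {
    // CHAR(n) and VARCHAR(n) both surface as StringType after this change
    assert(DataTypeParser.parse("cHaR(27)") == StringType)
    assert(DataTypeParser.parse("varchAr(20)") == StringType)
    // the same holds when the type is nested inside a complex type
    assert(DataTypeParser.parse("array<char(10)>") == ArrayType(StringType, containsNull = true))
  }
}
```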
--- .../sql/catalyst/util/DataTypeParser.scala | 6 ++++- .../catalyst/util/DataTypeParserSuite.scala | 8 ++++-- .../spark/sql/sources/TableScanSuite.scala | 5 ++++ .../spark/sql/hive/HiveInspectors.scala | 25 ++++++++++++++++--- .../apache/spark/sql/hive/TableReader.scala | 3 +++ .../spark/sql/hive/client/HiveShim.scala | 3 ++- 6 files changed, 43 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala index 2b83651f9086..515c071c283b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeParser.scala @@ -52,7 +52,8 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT | "(?i)date".r ^^^ DateType | "(?i)timestamp".r ^^^ TimestampType | - varchar + varchar | + char protected lazy val fixedDecimalType: Parser[DataType] = ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { @@ -60,6 +61,9 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { DecimalType(precision.toInt, scale.toInt) } + protected lazy val char: Parser[DataType] = + "(?i)char".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType + protected lazy val varchar: Parser[DataType] = "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala index 1e3409a9db6e..bebf70896547 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DataTypeParserSuite.scala @@ -49,7 +49,9 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("DATE", DateType) checkDataType("timestamp", TimestampType) checkDataType("string", StringType) + checkDataType("ChaR(5)", StringType) checkDataType("varchAr(20)", StringType) + checkDataType("cHaR(27)", StringType) checkDataType("BINARY", BinaryType) checkDataType("array", ArrayType(DoubleType, true)) @@ -83,7 +85,8 @@ class DataTypeParserSuite extends SparkFunSuite { |struct< | struct:struct, | MAP:Map, - | arrAy:Array> + | arrAy:Array, + | anotherArray:Array> """.stripMargin, StructType( StructField("struct", @@ -91,7 +94,8 @@ class DataTypeParserSuite extends SparkFunSuite { StructField("deciMal", DecimalType.USER_DEFAULT, true) :: StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: StructField("MAP", MapType(TimestampType, StringType), true) :: - StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) + StructField("arrAy", ArrayType(DoubleType, true), true) :: + StructField("anotherArray", ArrayType(StringType, true), true) :: Nil) ) // A column name can be a reserved word in our DDL parser and SqlParser. 
checkDataType( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 12af8068c398..26c1ff520406 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -85,6 +85,7 @@ case class AllDataTypesScan( Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -115,6 +116,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", + s"char_$i", Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -154,6 +156,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { |dateField dAte, |timestampField tiMestamp, |varcharField varchaR(12), + |charField ChaR(18), |arrayFieldSimple Array, |arrayFieldComplex Array>>, |mapFieldSimple MAP, @@ -207,6 +210,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: StructField("varcharField", StringType, true) :: + StructField("charField", StringType, true) :: StructField("arrayFieldSimple", ArrayType(IntegerType), true) :: StructField("arrayFieldComplex", ArrayType( @@ -248,6 +252,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { | dateField, | timestampField, | varcharField, + | charField, | arrayFieldSimple, | arrayFieldComplex, | mapFieldSimple, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 36f0708f9da3..95b57d6ad124 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import scala.collection.JavaConverters._ -import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} @@ -61,6 +61,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Type * Java Boxed Primitives: * org.apache.hadoop.hive.common.type.HiveVarchar + * org.apache.hadoop.hive.common.type.HiveChar * java.lang.String * java.lang.Integer * java.lang.Boolean @@ -75,6 +76,7 @@ import org.apache.spark.unsafe.types.UTF8String * java.sql.Timestamp * Writables: * org.apache.hadoop.hive.serde2.io.HiveVarcharWritable + * org.apache.hadoop.hive.serde2.io.HiveCharWritable * org.apache.hadoop.io.Text * org.apache.hadoop.io.IntWritable * org.apache.hadoop.hive.serde2.io.DoubleWritable @@ -93,7 +95,8 @@ import org.apache.spark.unsafe.types.UTF8String * Struct: Object[] / java.util.List / java POJO * Union: class StandardUnion { byte tag; Object object } * - * NOTICE: HiveVarchar is not supported by catalyst, it will be simply considered as String type. + * NOTICE: HiveVarchar/HiveChar is not supported by catalyst, it will be simply considered as + * String type. * * * 2. 
Hive ObjectInspector is a group of flexible APIs to inspect value in different data @@ -137,6 +140,7 @@ import org.apache.spark.unsafe.types.UTF8String * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector + * WritableConstantHiveCharObjectInspector * WritableConstantHiveDecimalObjectInspector * WritableConstantTimestampObjectInspector * WritableConstantIntObjectInspector @@ -259,6 +263,8 @@ private[hive] trait HiveInspectors { UTF8String.fromString(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue) + case poi: WritableConstantHiveCharObjectInspector => + UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -303,11 +309,15 @@ private[hive] trait HiveInspectors { case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector case pi: PrimitiveObjectInspector => pi match { - // We think HiveVarchar is also a String + // We think HiveVarchar/HiveChar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) case hvoi: HiveVarcharObjectInspector => UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) + case hvoi: HiveCharObjectInspector if hvoi.preferWritable() => + UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue) + case hvoi: HiveCharObjectInspector => + UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => UTF8String.fromString(x.getPrimitiveWritableObject(data).toString) case x: StringObjectInspector => @@ -377,6 +387,15 @@ private[hive] trait HiveInspectors { null } + case _: JavaHiveCharObjectInspector => + (o: Any) => + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveChar(s, s.size) + } else { + null + } + case _: JavaHiveDecimalObjectInspector => (o: Any) => if (o != null) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 69f481c49a65..70ee02823eeb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -382,6 +382,9 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) + case oi: HiveCharObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 48bbb21e6c1d..346840079b85 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -321,7 
+321,8 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. val varcharKeys = table.getPartitionKeys.asScala - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) .map(col => col.getName).toSet filters.collect { From 1a5baaa6517872b9a4fd6cd41c4b2cf1e390f6d1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:13:59 -0800 Subject: [PATCH 849/896] [SPARK-11894][SQL] fix isNull for GetInternalRowField We should use `InternalRow.isNullAt` to check if the field is null before calling `InternalRow.getXXX` Thanks gatorsmile who discovered this bug. Author: Wenchen Fan Closes #9904 from cloud-fan/null. --- .../sql/catalyst/expressions/objects.scala | 23 ++++++++----------- .../org/apache/spark/sql/DatasetSuite.scala | 15 +++++++++++- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 82317d338516..4a1f419f0ad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -236,11 +236,6 @@ case class NewInstance( } if (propagateNull) { - val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" s""" @@ -531,15 +526,15 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val row = child.gen(ctx) - s""" - ${row.code} - final boolean ${ev.isNull} = ${row.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)}; - } - """ + nullSafeCodeGen(ctx, ev, eval => { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + } + """ + }) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 9da02550b39c..cc8e4325fd2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -386,7 +386,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq((JavaData(1), 1L), (JavaData(2), 1L))) } - ignore("Java encoder self join") { + test("Java encoder self join") { implicit val kryoEncoder = Encoders.javaSerialization[JavaData] val ds = Seq(JavaData(1), JavaData(2)).toDS() assert(ds.joinWith(ds, lit(true)).collect().toSet == @@ -396,6 +396,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (JavaData(2), JavaData(1)), (JavaData(2), JavaData(2)))) } + + test("SPARK-11894: Incorrect results are returned when using null") { + val nullInt = null.asInstanceOf[java.lang.Integer] + val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() 
+ val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() + + checkAnswer( + ds1.joinWith(ds2, lit(true)), + ((nullInt, "1"), (nullInt, "1")), + ((new java.lang.Integer(22), "2"), (nullInt, "1")), + ((nullInt, "1"), (new java.lang.Integer(22), "2")), + ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) + } } From f2996e0d12eeb989b1bfa51a3f6fa54ce1ed4fca Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:15:40 -0800 Subject: [PATCH 850/896] [SPARK-11921][SQL] fix `nullable` of encoder schema Author: Wenchen Fan Closes #9906 from cloud-fan/nullable. --- .../catalyst/encoders/ExpressionEncoder.scala | 15 +++++++- .../encoders/ExpressionEncoderSuite.scala | 38 ++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6eeba1442c1f..7bc9aed0b204 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -54,8 +54,13 @@ object ExpressionEncoder { val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) val fromRowExpression = ScalaReflection.constructorFor[T] + val schema = ScalaReflection.schemaFor[T] match { + case ScalaReflection.Schema(s: StructType, _) => s + case ScalaReflection.Schema(dt, nullable) => new StructType().add("value", dt, nullable) + } + new ExpressionEncoder[T]( - toRowExpression.dataType, + schema, flat, toRowExpression.flatten, fromRowExpression, @@ -71,7 +76,13 @@ object ExpressionEncoder { encoders.foreach(_.assertUnresolved()) val schema = StructType(encoders.zipWithIndex.map { - case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + case (e, i) => + val (dataType, nullable) = if (e.flat) { + e.schema.head.dataType -> e.schema.head.nullable + } else { + e.schema -> true + } + StructField(s"_${i + 1}", dataType, nullable) }) val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 76459b34a484..d6ca138672ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} -import org.apache.spark.sql.types.ArrayType +import org.apache.spark.sql.types.{StructType, ArrayType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -238,6 +238,42 @@ class ExpressionEncoderSuite extends SparkFunSuite { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + test("nullable of encoder schema") { + def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { + assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq) + } + + // test for flat encoders + checkNullable[Int](false) + checkNullable[Option[Int]](true) + checkNullable[java.lang.Integer](true) + 
checkNullable[String](true) + + // test for product encoders + checkNullable[(String, Int)](true, false) + checkNullable[(Int, java.lang.Long)](false, true) + + // test for nested product encoders + { + val schema = ExpressionEncoder[(Int, (String, Int))].schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + + // test for tupled encoders + { + val schema = ExpressionEncoder.tuple( + ExpressionEncoder[Int], + ExpressionEncoder[(String, Int)]).schema + assert(schema(0).nullable === false) + assert(schema(1).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](0).nullable === true) + assert(schema(1).dataType.asInstanceOf[StructType](1).nullable === false) + } + } + private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() outers.put(getClass.getName, this) private def encodeDecodeTest[T : ExpressionEncoder]( From 946b406519af58c79041217e6f93854b6cf80acd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 23 Nov 2015 10:39:33 -0800 Subject: [PATCH 851/896] [SPARK-11913][SQL] support typed aggregate with complex buffer schema Author: Wenchen Fan Closes #9898 from cloud-fan/agg. --- .../aggregate/TypedAggregateExpression.scala | 25 +++++++---- .../spark/sql/DatasetAggregatorSuite.scala | 41 ++++++++++++++++++- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 6ce41aaf01e2..a9719128a626 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -23,9 +23,8 @@ import org.apache.spark.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -46,14 +45,12 @@ object TypedAggregateExpression { /** * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has * the following limitations: - * - It assumes the aggregator reduces and returns a single column of type `long`. - * - It might only work when there is a single aggregator in the first column. * - It assumes the aggregator has a zero, `0`. */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. - bEncoder: ExpressionEncoder[Any], // Should be bound. 
+ unresolvedBEncoder: ExpressionEncoder[Any], cEncoder: ExpressionEncoder[Any], children: Seq[Attribute], mutableAggBufferOffset: Int, @@ -80,10 +77,14 @@ case class TypedAggregateExpression( override lazy val inputTypes: Seq[DataType] = Nil - override val aggBufferSchema: StructType = bEncoder.schema + override val aggBufferSchema: StructType = unresolvedBEncoder.schema override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + val bEncoder = unresolvedBEncoder + .resolve(aggBufferAttributes, OuterScopes.outerScopes) + .bind(aggBufferAttributes) + // Note: although this simply copies aggBufferAttributes, this common code can not be placed // in the superclass because that will lead to initialization ordering issues. override val inputAggBufferAttributes: Seq[AttributeReference] = @@ -93,12 +94,18 @@ case class TypedAggregateExpression( lazy val boundA = aEncoder.get private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { - // todo: need a more neat way to assign the value. var i = 0 while (i < aggBufferAttributes.length) { + val offset = mutableAggBufferOffset + i aggBufferSchema(i).dataType match { - case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i)) - case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i)) + case BooleanType => buffer.setBoolean(offset, value.getBoolean(i)) + case ByteType => buffer.setByte(offset, value.getByte(i)) + case ShortType => buffer.setShort(offset, value.getShort(i)) + case IntegerType => buffer.setInt(offset, value.getInt(i)) + case LongType => buffer.setLong(offset, value.getLong(i)) + case FloatType => buffer.setFloat(offset, value.getFloat(i)) + case DoubleType => buffer.setDouble(offset, value.getDouble(i)) + case other => buffer.update(offset, value.get(i, other)) } i += 1 } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 937758979001..19dce5d1e2f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -67,7 +67,7 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L } case class AggData(a: Int, b: String) -object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { +object ClassInputAgg extends Aggregator[AggData, Int, Int] { /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ override def zero: Int = 0 @@ -88,6 +88,28 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { override def merge(b1: Int, b2: Int): Int = b1 + b2 } +object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + override def zero: (Int, AggData) = 0 -> AggData(0, "0") + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + */ + override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a) + + /** + * Transform the output of the reduction. 
+ */ + override def finish(reduction: (Int, AggData)): Int = reduction._1 + + /** + * Merge two intermediate values + */ + override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) = + (b1._1 + b2._1, b1._2) +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -168,4 +190,21 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ds.groupBy(_.b).agg(ClassInputAgg.toColumn), ("one", 1)) } + + test("typed aggregation: complex input") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + + checkAnswer( + ds.select(ComplexBufferAgg.toColumn), + 2 + ) + + checkAnswer( + ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn), + (1.5, 2)) + + checkAnswer( + ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn), + ("one", 1), ("two", 1)) + } } From 5fd86e4fc2e06d2403ca538ae417580c93b69e06 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 23 Nov 2015 10:41:17 -0800 Subject: [PATCH 852/896] [SPARK-7173][YARN] Add label expression support for application master Add label expression support for AM to restrict it runs on the specific set of nodes. I tested it locally and works fine. sryza and vanzin please help to review, thanks a lot. Author: jerryshao Closes #9800 from jerryshao/SPARK-7173. --- docs/running-on-yarn.md | 9 +++++++ .../org/apache/spark/deploy/yarn/Client.scala | 26 ++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index db6bfa69ee0f..925a1e0ba6fc 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -326,6 +326,15 @@ If you need a reference to the proper location to put log files in the YARN so t Otherwise, the client process will exit after submission. + + spark.yarn.am.nodeLabelExpression + (none) + + A YARN node label expression that restricts the set of nodes AM will be scheduled on. + Only versions of YARN greater than or equal to 2.6 support node label expressions, so when + running against earlier versions, this property will be ignored. 
+ + spark.yarn.executor.nodeLabelExpression (none) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index ba799884f568..a77a3e2420e2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -225,7 +225,31 @@ private[spark] class Client( val capability = Records.newRecord(classOf[Resource]) capability.setMemory(args.amMemory + amMemoryOverhead) capability.setVirtualCores(args.amCores) - appContext.setResource(capability) + + if (sparkConf.contains("spark.yarn.am.nodeLabelExpression")) { + try { + val amRequest = Records.newRecord(classOf[ResourceRequest]) + amRequest.setResourceName(ResourceRequest.ANY) + amRequest.setPriority(Priority.newInstance(0)) + amRequest.setCapability(capability) + amRequest.setNumContainers(1) + val amLabelExpression = sparkConf.get("spark.yarn.am.nodeLabelExpression") + val method = amRequest.getClass.getMethod("setNodeLabelExpression", classOf[String]) + method.invoke(amRequest, amLabelExpression) + + val setResourceRequestMethod = + appContext.getClass.getMethod("setAMContainerResourceRequest", classOf[ResourceRequest]) + setResourceRequestMethod.invoke(appContext, amRequest) + } catch { + case e: NoSuchMethodException => + logWarning("Ignoring spark.yarn.am.nodeLabelExpression because the version " + + "of YARN does not support it") + appContext.setResource(capability) + } + } else { + appContext.setResource(capability) + } + appContext } From 5231cd5acaae69d735ba3209531705cc222f3cfb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 10:45:23 -0800 Subject: [PATCH 853/896] [SPARK-11762][NETWORK] Account for active streams when couting outstanding requests. This way the timeout handling code can correctly close "hung" channels that are processing streams. Author: Marcelo Vanzin Closes #9747 from vanzin/SPARK-11762. 
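[Editor's note] The bookkeeping idea is easy to state outside the transport code: a channel that still has stream callbacks queued, or is actively feeding a stream through the frame decoder, has outstanding work even though no fetch or RPC is pending, so the idle-timeout check has to count it. The following is a simplified Scala sketch of that accounting only; the class and method names are invented for illustration and are not the classes touched in the diff below.

    import java.util.concurrent.ConcurrentLinkedQueue
    import java.util.concurrent.atomic.AtomicBoolean

    // Simplified bookkeeping, not the real handler: "outstanding" must include
    // queued stream callbacks and a stream that is currently being decoded,
    // otherwise the idle-timeout sweep sees zero outstanding requests.
    class StreamAccounting {
      private val pendingStreams = new ConcurrentLinkedQueue[String]()
      private val streamActive = new AtomicBoolean(false)

      // A stream request was sent; its callback is waiting for the response.
      def addStream(id: String): Unit = pendingStreams.offer(id)

      // The stream response arrived and data started draining.
      def activate(): Unit = { pendingStreams.poll(); streamActive.set(true) }

      // The stream completed or failed.
      def deactivate(): Unit = streamActive.set(false)

      // Timeout handling only closes an overdue channel when this is non-zero.
      def numOutstanding: Int =
        pendingStreams.size + (if (streamActive.get) 1 else 0)
    }

With this counting, an overdue connection stuck in the middle of a stream is treated as hung and closed instead of being left open indefinitely.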
--- .../network/client/StreamInterceptor.java | 12 ++++++++- .../client/TransportResponseHandler.java | 15 +++++++++-- .../TransportResponseHandlerSuite.java | 27 +++++++++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index 02230a00e69f..88ba3ccebdf2 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -30,13 +30,19 @@ */ class StreamInterceptor implements TransportFrameDecoder.Interceptor { + private final TransportResponseHandler handler; private final String streamId; private final long byteCount; private final StreamCallback callback; private volatile long bytesRead; - StreamInterceptor(String streamId, long byteCount, StreamCallback callback) { + StreamInterceptor( + TransportResponseHandler handler, + String streamId, + long byteCount, + StreamCallback callback) { + this.handler = handler; this.streamId = streamId; this.byteCount = byteCount; this.callback = callback; @@ -45,11 +51,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor { @Override public void exceptionCaught(Throwable cause) throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, cause); } @Override public void channelInactive() throws Exception { + handler.deactivateStream(); callback.onFailure(streamId, new ClosedChannelException()); } @@ -65,8 +73,10 @@ public boolean handle(ByteBuf buf) throws Exception { RuntimeException re = new IllegalStateException(String.format( "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead)); callback.onFailure(streamId, re); + handler.deactivateStream(); throw re; } else if (bytesRead == byteCount) { + handler.deactivateStream(); callback.onComplete(streamId); } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index ed3f36af5804..cc88991b588c 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,6 +57,7 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; private final Queue streamCallbacks; + private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ private final AtomicLong timeOfLastRequestNs; @@ -87,9 +89,15 @@ public void removeRpcRequest(long requestId) { } public void addStreamCallback(StreamCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); streamCallbacks.offer(callback); } + @VisibleForTesting + public void deactivateStream() { + streamActive = false; + } + /** * Fire the failure callback for all outstanding requests. This is called when we have an * uncaught exception or pre-mature connection termination. 
@@ -177,14 +185,16 @@ public void handle(ResponseMessage message) { StreamResponse resp = (StreamResponse) message; StreamCallback callback = streamCallbacks.poll(); if (callback != null) { - StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount, + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, callback); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); frameDecoder.setInterceptor(interceptor); + streamActive = true; } catch (Exception e) { logger.error("Error installing stream handler.", e); + deactivateStream(); } } else { logger.error("Could not find callback for StreamResponse."); @@ -208,7 +218,8 @@ public void handle(ResponseMessage message) { /** Returns total number of outstanding requests (fetch requests + rpcs) */ public int numOutstandingRequests() { - return outstandingFetches.size() + outstandingRpcs.size(); + return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() + + (streamActive ? 1 : 0); } /** Returns the time in nanoseconds of when the last request was sent out. */ diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 17a03ebe88a9..30144f4a9fc7 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; @@ -28,12 +29,16 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.StreamFailure; +import org.apache.spark.network.protocol.StreamResponse; +import org.apache.spark.network.util.TransportFrameDecoder; public class TransportResponseHandlerSuite { @Test @@ -112,4 +117,26 @@ public void handleFailedRPC() { verify(callback, times(1)).onFailure((Throwable) any()); assertEquals(0, handler.numOutstandingRequests()); } + + @Test + public void testActiveStreams() { + Channel c = new LocalChannel(); + c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + TransportResponseHandler handler = new TransportResponseHandler(c); + + StreamResponse response = new StreamResponse("stream", 1234L, null); + StreamCallback cb = mock(StreamCallback.class); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(response); + assertEquals(1, handler.numOutstandingRequests()); + handler.deactivateStream(); + assertEquals(0, handler.numOutstandingRequests()); + + StreamFailure failure = new StreamFailure("stream", "uh-oh"); + handler.addStreamCallback(cb); + assertEquals(1, handler.numOutstandingRequests()); + handler.handle(failure); + assertEquals(0, 
handler.numOutstandingRequests()); + } } From 98d7ec7df4bb115dbd84cb9acd744b6c8abfebd5 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 23 Nov 2015 11:51:29 -0800 Subject: [PATCH 854/896] [SPARK-11920][ML][DOC] ML LinearRegression should use correct dataset in examples and user guide doc ML ```LinearRegression``` use ```data/mllib/sample_libsvm_data.txt``` as dataset in examples and user guide doc, but it's actually classification dataset rather than regression dataset. We should use ```data/mllib/sample_linear_regression_data.txt``` instead. The deeper causes is that ```LinearRegression``` with "normal" solver can not solve this dataset correctly, may be due to the ill condition and unreasonable label. This issue has been reported at [SPARK-11918](https://issues.apache.org/jira/browse/SPARK-11918). It will confuse users if they run the example code but get exception, so we should make this change which can clearly illustrate the usage of ```LinearRegression``` algorithm. Author: Yanbo Liang Closes #9905 from yanboliang/spark-11920. --- .../examples/ml/JavaLinearRegressionWithElasticNetExample.java | 2 +- .../src/main/python/ml/linear_regression_with_elastic_net.py | 3 ++- .../examples/ml/LinearRegressionWithElasticNetExample.scala | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java index 593f8fb3e9fe..4ad7676c8d32 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java @@ -37,7 +37,7 @@ public static void main(String[] args) { // $example on$ // Load training data DataFrame training = sqlContext.read().format("libsvm") - .load("data/mllib/sample_libsvm_data.txt"); + .load("data/mllib/sample_linear_regression_data.txt"); LinearRegression lr = new LinearRegression() .setMaxIter(10) diff --git a/examples/src/main/python/ml/linear_regression_with_elastic_net.py b/examples/src/main/python/ml/linear_regression_with_elastic_net.py index b0278276330c..a4cd40cf2672 100644 --- a/examples/src/main/python/ml/linear_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/linear_regression_with_elastic_net.py @@ -29,7 +29,8 @@ # $example on$ # Load training data - training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + training = sqlContext.read.format("libsvm")\ + .load("data/mllib/sample_linear_regression_data.txt") lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala index 5a51ece6f9ba..22c824cea84d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala @@ -33,7 +33,8 @@ object LinearRegressionWithElasticNetExample { // $example on$ // Load training data - val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val training = sqlCtx.read.format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt") val lr = new LinearRegression() .setMaxIter(10) From 
f6dcc6e96ad3f88563d717d5b6c45394b44db747 Mon Sep 17 00:00:00 2001 From: Mortada Mehyar Date: Mon, 23 Nov 2015 12:03:15 -0800 Subject: [PATCH 855/896] [SPARK-11837][EC2] python3 compatibility for launching ec2 m3 instances this currently breaks for python3 because `string` module doesn't have `letters` anymore, instead `ascii_letters` should be used Author: Mortada Mehyar Closes #9797 from mortada/python3_fix. --- ec2/spark_ec2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 9327e21e43db..9fd652a3df4c 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -595,7 +595,7 @@ def launch_cluster(conn, opts, cluster_name): dev = BlockDeviceType() dev.ephemeral_name = 'ephemeral%d' % i # The first ephemeral drive is /dev/sdb. - name = '/dev/sd' + string.letters[i + 1] + name = '/dev/sd' + string.ascii_letters[i + 1] block_map[name] = dev # Launch slaves From 1b6e938be836786bac542fa430580248161e5403 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 23 Nov 2015 13:19:10 -0800 Subject: [PATCH 856/896] [SPARK-4424] Remove spark.driver.allowMultipleContexts override in tests This patch removes `spark.driver.allowMultipleContexts=true` from our test configuration. The multiple SparkContexts check was originally disabled because certain tests suites in SQL needed to create multiple contexts. As far as I know, this configuration change is no longer necessary, so we should remove it in order to make it easier to find test cleanup bugs. Author: Josh Rosen Closes #9865 from JoshRosen/SPARK-4424. --- pom.xml | 2 -- project/SparkBuild.scala | 1 - 2 files changed, 3 deletions(-) diff --git a/pom.xml b/pom.xml index ad849112ce76..234fd5dea1a6 100644 --- a/pom.xml +++ b/pom.xml @@ -1958,7 +1958,6 @@ false false false - true true src @@ -1997,7 +1996,6 @@ 1 false false - true true __not_used__ diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 67724c4e9e41..f575f0012d59 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -632,7 +632,6 @@ object TestSettings { javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", - javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test += "-Dderby.system.durability=test", From 1d9120201012213edb1971a09e0849336dbb9415 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 23 Nov 2015 13:44:30 -0800 Subject: [PATCH 857/896] [SPARK-11836][SQL] udf/cast should not create new SQLContext They should use the existing SQLContext. Author: Davies Liu Closes #9914 from davies/create_udf. 
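[Editor's note] The underlying idiom is get-or-create: look up the context that already exists instead of constructing a new one per call. For reference, the same pattern on the Scala API, as a minimal illustrative local-mode example against the 1.x API; the actual fix below only touches the PySpark wrappers.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    object ContextReuse {
      def main(args: Array[String]): Unit = {
        val sc = SparkContext.getOrCreate(
          new SparkConf().setAppName("context-reuse").setMaster("local[*]"))

        // Reuse the shared SQLContext instead of `new SQLContext(sc)`, so UDFs,
        // temp tables and configuration stay visible to every caller.
        val sqlContext = SQLContext.getOrCreate(sc)
        sqlContext.udf.register("plusOne", (x: Long) => x + 1)

        // A later getOrCreate returns the same instance, so the UDF is still registered.
        SQLContext.getOrCreate(sc).range(3).selectExpr("plusOne(id)").show()

        sc.stop()
      }
    }

Because getOrCreate always hands back the shared instance, anything registered on it, such as the UDF above, remains visible to later callers.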
--- python/pyspark/sql/column.py | 7 ++++--- python/pyspark/sql/functions.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 9ca8e1f264cf..81fd4e782628 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -346,9 +346,10 @@ def cast(self, dataType): if isinstance(dataType, basestring): jc = self._jc.cast(dataType) elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) + from pyspark.sql import SQLContext + sc = SparkContext.getOrCreate() + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(dataType.json()) jc = self._jc.cast(jdt) else: raise TypeError("unexpected type: %s" % type(dataType)) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c3da513c1389..a1ca723bbd7a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1457,14 +1457,15 @@ def __init__(self, func, returnType, name=None): self._judf = self._create_judf(name) def _create_judf(self, name): + from pyspark.sql import SQLContext f, returnType = self.func, self.returnType # put them in closure `func` func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) - sc = SparkContext._active_spark_context + sc = SparkContext.getOrCreate() pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(self.returnType.json()) + ctx = SQLContext.getOrCreate(sc) + jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes, From 242be7daed9b01d19794bb2cf1ac421fe5ab7262 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Mon, 23 Nov 2015 13:46:34 -0800 Subject: [PATCH 858/896] [SPARK-11910][STREAMING][DOCS] Update twitter4j dependency version Author: Luciano Resende Closes #9892 from lresende/SPARK-11910. --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 96b36b7a7320..ed6b28c28213 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -723,7 +723,7 @@ Some of these advanced sources are as follows. - **Kinesis:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kinesis Client Library 1.2.1. See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. -- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using +- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j to get the public stream of tweets using [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by Twitter4J library. 
You can either get the public stream, or get the filtered stream based on a From 7cfa4c6bc36d97e459d4adee7b03d537d63c337e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 13:51:43 -0800 Subject: [PATCH 859/896] [SPARK-11865][NETWORK] Avoid returning inactive client in TransportClientFactory. There's a very narrow race here where it would be possible for the timeout handler to close a channel after the client factory verified that the channel was still active. This change makes sure the client is marked as being recently in use so that the timeout handler does not close it until a new timeout cycle elapses. Author: Marcelo Vanzin Closes #9853 from vanzin/SPARK-11865. --- .../spark/network/client/TransportClient.java | 9 ++++- .../client/TransportClientFactory.java | 15 ++++++-- .../client/TransportResponseHandler.java | 9 +++-- .../server/TransportChannelHandler.java | 36 ++++++++++++------- 4 files changed, 52 insertions(+), 17 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index a0ba223e340a..876fcd846791 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -73,10 +73,12 @@ public class TransportClient implements Closeable { private final Channel channel; private final TransportResponseHandler handler; @Nullable private String clientId; + private volatile boolean timedOut; public TransportClient(Channel channel, TransportResponseHandler handler) { this.channel = Preconditions.checkNotNull(channel); this.handler = Preconditions.checkNotNull(handler); + this.timedOut = false; } public Channel getChannel() { @@ -84,7 +86,7 @@ public Channel getChannel() { } public boolean isActive() { - return channel.isOpen() || channel.isActive(); + return !timedOut && (channel.isOpen() || channel.isActive()); } public SocketAddress getSocketAddress() { @@ -263,6 +265,11 @@ public void onFailure(Throwable e) { } } + /** Mark this channel as having timed out. */ + public void timeOut() { + this.timedOut = true; + } + @Override public void close() { // close is a local operation and should finish with milliseconds; timeout just to be safe diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 42a4f664e697..659c47160c7b 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -136,8 +136,19 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO TransportClient cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null && cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); - return cachedClient; + // Make sure that the channel will not timeout by updating the last use time of the + // handler. Then check that the client is still alive, in case it timed out before + // this code was able to update things. 
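[Editor's note] The fix follows a touch-then-recheck pattern: refresh the cached client's last-use timestamp under the same lock the idle-timeout sweep uses, and only then re-check liveness before returning the client. Below is a self-contained, simplified Scala sketch of that pattern; the names are invented, and the real change in the diff that follows operates on the transport handler and factory instead.

    // Simplified sketch of the touch-then-recheck pattern; these names are
    // invented for illustration and are not the actual transport classes.
    class PooledClient {
      @volatile private var lastUsedNanos = System.nanoTime()
      @volatile private var closed = false

      def touch(): Unit = lastUsedNanos = System.nanoTime()
      def isActive: Boolean = !closed
      def close(): Unit = closed = true
      def idleFor(now: Long): Long = now - lastUsedNanos
    }

    class ClientPool(timeoutNanos: Long) {
      private val lock = new Object
      private var cached = new PooledClient

      def get(): PooledClient = lock.synchronized {
        // Refresh the last-use time before re-checking liveness, so the idle
        // sweep (which takes the same lock) cannot close the client in between.
        cached.touch()
        if (!cached.isActive) {
          cached = new PooledClient
        }
        cached
      }

      def sweep(): Unit = lock.synchronized {
        if (cached.idleFor(System.nanoTime()) > timeoutNanos) {
          cached.close()
        }
      }
    }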
+ TransportChannelHandler handler = cachedClient.getChannel().pipeline() + .get(TransportChannelHandler.class); + synchronized (handler) { + handler.getResponseHandler().updateTimeOfLastRequest(); + } + + if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); + return cachedClient; + } } // If we reach here, we don't have an existing connection open. Let's create a new one. diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index cc88991b588c..be181e066082 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -71,7 +71,7 @@ public TransportResponseHandler(Channel channel) { } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingFetches.put(streamChunkId, callback); } @@ -80,7 +80,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { - timeOfLastRequestNs.set(System.nanoTime()); + updateTimeOfLastRequest(); outstandingRpcs.put(requestId, callback); } @@ -227,4 +227,9 @@ public long getTimeOfLastRequestNs() { return timeOfLastRequestNs.get(); } + /** Updates the time of the last request to the current system time. */ + public void updateTimeOfLastRequest() { + timeOfLastRequestNs.set(System.nanoTime()); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index f8fcd1c3d7d7..29d688a67578 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -116,20 +116,32 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc // there are outstanding requests, we also do a secondary consistency check to ensure // there's no race between the idle timeout and incrementing the numOutstandingRequests // (see SPARK-7003). - boolean isActuallyOverdue = - System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { - if (responseHandler.numOutstandingRequests() > 0) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - ctx.close(); - } else if (closeIdleConnections) { - // While CloseIdleConnections is enable, we also close idle connection - ctx.close(); + // + // To avoid a race between TransportClientFactory.createClient() and this code which could + // result in an inactive client being returned, this needs to run in a synchronized block. 
+ synchronized (this) { + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { + if (responseHandler.numOutstandingRequests() > 0) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + client.timeOut(); + ctx.close(); + } else if (closeIdleConnections) { + // While CloseIdleConnections is enable, we also close idle connection + client.timeOut(); + ctx.close(); + } } } } } + + public TransportResponseHandler getResponseHandler() { + return responseHandler; + } + } From c2467dadae8ce44010a912ee91c429310f8add65 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 23 Nov 2015 13:54:19 -0800 Subject: [PATCH 860/896] [SPARK-11140][CORE] Transfer files using network lib when using NettyRpcEnv. This change abstracts the code that serves jars / files to executors so that each RpcEnv can have its own implementation; the akka version uses the existing HTTP-based file serving mechanism, while the netty versions uses the new stream support added to the network lib, which makes file transfers benefit from the easier security configuration of the network library, and should also reduce overhead overall. The change includes a small fix to TransportChannelHandler so that it propagates user events to downstream handlers. Author: Marcelo Vanzin Closes #9530 from vanzin/SPARK-11140. --- .../scala/org/apache/spark/SparkContext.scala | 8 +- .../scala/org/apache/spark/SparkEnv.scala | 14 -- .../scala/org/apache/spark/rpc/RpcEnv.scala | 46 ++++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 60 +++++++- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 138 ++++++++++++++++-- .../spark/rpc/netty/NettyStreamManager.scala | 63 ++++++++ .../scala/org/apache/spark/util/Utils.scala | 9 ++ .../org/apache/spark/rpc/RpcEnvSuite.scala | 39 ++++- .../rpc/netty/NettyRpcHandlerSuite.scala | 10 +- docs/configuration.md | 2 + docs/security.md | 5 +- .../launcher/AbstractCommandBuilder.java | 2 +- .../client/TransportClientFactory.java | 6 +- .../server/TransportChannelHandler.java | 1 + 14 files changed, 356 insertions(+), 47 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index af4456c05b0a..b153a7b08e59 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } val key = if (!isLocal && scheme == "file") { - env.httpFileServer.addFile(new File(uri.getPath)) + env.rpcEnv.fileServer.addFile(new File(uri.getPath)) } else { schemeCorrectedPath } @@ -1630,7 +1630,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli var key = "" if (path.contains("\\")) { // For local paths with backslashes on Windows, URI throws an exception - key = env.httpFileServer.addJar(new File(path)) + key = env.rpcEnv.fileServer.addJar(new File(path)) } else { val uri = new URI(path) key = uri.getScheme match { @@ -1644,7 +1644,7 @@ class SparkContext(config: SparkConf) extends Logging with 
ExecutorAllocationCli // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() try { - env.httpFileServer.addJar(new File(fileName)) + env.rpcEnv.fileServer.addJar(new File(fileName)) } catch { case e: Exception => // For now just log an error but allow to go through so spark examples work. @@ -1655,7 +1655,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } else { try { - env.httpFileServer.addJar(new File(uri.getPath)) + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) } catch { case exc: FileNotFoundException => logError(s"Jar not found at $path") diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 88df27f733f2..84230e32a446 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -66,7 +66,6 @@ class SparkEnv ( val blockTransferService: BlockTransferService, val blockManager: BlockManager, val securityManager: SecurityManager, - val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, val memoryManager: MemoryManager, @@ -91,7 +90,6 @@ class SparkEnv ( if (!isStopped) { isStopped = true pythonWorkers.values.foreach(_.stop()) - Option(httpFileServer).foreach(_.stop()) mapOutputTracker.stop() shuffleManager.stop() broadcastManager.stop() @@ -367,17 +365,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - val httpFileServer = - if (isDriver) { - val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(conf, securityManager, fileServerPort) - server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) - server - } else { - null - } - val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. @@ -422,7 +409,6 @@ object SparkEnv extends Logging { blockTransferService, blockManager, securityManager, - httpFileServer, sparkFilesDir, metricsSystem, memoryManager, diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index a560fd10cdf7..3d7d281b0dd6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,6 +17,9 @@ package org.apache.spark.rpc +import java.io.File +import java.nio.channels.ReadableByteChannel + import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} @@ -132,8 +135,51 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. */ def deserialize[T](deserializationAction: () => T): T + + /** + * Return the instance of the file server used to serve files. This may be `null` if the + * RpcEnv is not operating in server mode. + */ + def fileServer: RpcEnvFileServer + + /** + * Open a channel to download a file from the given URI. If the URIs returned by the + * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to + * retrieve the files. + * + * @param uri URI with location of the file. + */ + def openChannel(uri: String): ReadableByteChannel + } +/** + * A server used by the RpcEnv to server files to other processes owned by the application. 
+ * + * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or + * it can return "spark" URIs which will be handled by `RpcEnv#fetchFile`. + */ +private[spark] trait RpcEnvFileServer { + + /** + * Adds a file to be served by this RpcEnv. This is used to serve files from the driver + * to executors when they're stored on the driver's local file system. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addFile(file: File): String + + /** + * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using + * `SparkContext.addJar`. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addJar(file: File): String + +} private[spark] case class RpcEnvConfig( conf: SparkConf, diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 059a7e10ec12..94dbec593c31 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -17,6 +17,8 @@ package org.apache.spark.rpc.akka +import java.io.File +import java.nio.channels.ReadableByteChannel import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future @@ -30,7 +32,7 @@ import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} import akka.serialization.JavaSerializer -import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.rpc._ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} @@ -41,7 +43,10 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} * remove Akka from the dependencies. 
*/ private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + val actorSystem: ActorSystem, + val securityManager: SecurityManager, + conf: SparkConf, + boundPort: Int) extends RpcEnv(conf) with Logging { private val defaultAddress: RpcAddress = { @@ -64,6 +69,8 @@ private[spark] class AkkaRpcEnv private[akka] ( */ private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + private val _fileServer = new AkkaFileServer(conf, securityManager) + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { endpointToRef.put(endpoint, endpointRef) refToEndpoint.put(endpointRef, endpoint) @@ -223,6 +230,7 @@ private[spark] class AkkaRpcEnv private[akka] ( override def shutdown(): Unit = { actorSystem.shutdown() + _fileServer.shutdown() } override def stop(endpoint: RpcEndpointRef): Unit = { @@ -241,6 +249,52 @@ private[spark] class AkkaRpcEnv private[akka] ( deserializationAction() } } + + override def openChannel(uri: String): ReadableByteChannel = { + throw new UnsupportedOperationException( + "AkkaRpcEnv's files should be retrieved using an HTTP client.") + } + + override def fileServer: RpcEnvFileServer = _fileServer + +} + +private[akka] class AkkaFileServer( + conf: SparkConf, + securityManager: SecurityManager) extends RpcEnvFileServer { + + @volatile private var httpFileServer: HttpFileServer = _ + + override def addFile(file: File): String = { + getFileServer().addFile(file) + } + + override def addJar(file: File): String = { + getFileServer().addJar(file) + } + + def shutdown(): Unit = { + if (httpFileServer != null) { + httpFileServer.stop() + } + } + + private def getFileServer(): HttpFileServer = { + if (httpFileServer == null) synchronized { + if (httpFileServer == null) { + httpFileServer = startFileServer() + } + } + httpFileServer + } + + private def startFileServer(): HttpFileServer = { + val fileServerPort = conf.getInt("spark.fileserver.port", 0) + val server = new HttpFileServer(conf, securityManager, fileServerPort) + server.initialize() + server + } + } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -249,7 +303,7 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( config.name, config.host, config.port, config.conf, config.securityManager) actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") - new AkkaRpcEnv(actorSystem, config.conf, boundPort) + new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 3ce359868039..68701f609f77 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -20,6 +20,7 @@ import java.io._ import java.lang.{Boolean => JBoolean} import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer +import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.Nullable @@ -45,27 +46,39 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - private val transportConf = SparkTransportConf.fromSparkConf( + private[netty] val transportConf = SparkTransportConf.fromSparkConf( 
conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) + private val streamManager = new NettyStreamManager(this) + private val transportContext = new TransportContext(transportConf, - new NettyRpcHandler(dispatcher, this)) + new NettyRpcHandler(dispatcher, this, streamManager)) - private val clientFactory = { - val bootstraps: java.util.List[TransportClientBootstrap] = - if (securityManager.isAuthenticationEnabled()) { - java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, - securityManager.isSaslEncryptionEnabled())) - } else { - java.util.Collections.emptyList[TransportClientBootstrap] - } - transportContext.createClientFactory(bootstraps) + private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = { + if (securityManager.isAuthenticationEnabled()) { + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, + securityManager.isSaslEncryptionEnabled())) + } else { + java.util.Collections.emptyList[TransportClientBootstrap] + } } + private val clientFactory = transportContext.createClientFactory(createClientBootstraps()) + + /** + * A separate client factory for file downloads. This avoids using the same RPC handler as + * the main RPC context, so that events caused by these clients are kept isolated from the + * main RPC traffic. + * + * It also allows for different configuration of certain properties, such as the number of + * connections per peer. + */ + @volatile private var fileDownloadFactory: TransportClientFactory = _ + val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool @@ -292,6 +305,9 @@ private[netty] class NettyRpcEnv( if (clientConnectionExecutor != null) { clientConnectionExecutor.shutdownNow() } + if (fileDownloadFactory != null) { + fileDownloadFactory.close() + } } override def deserialize[T](deserializationAction: () => T): T = { @@ -300,6 +316,96 @@ private[netty] class NettyRpcEnv( } } + override def fileServer: RpcEnvFileServer = streamManager + + override def openChannel(uri: String): ReadableByteChannel = { + val parsedUri = new URI(uri) + require(parsedUri.getHost() != null, "Host name must be defined.") + require(parsedUri.getPort() > 0, "Port must be defined.") + require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.") + + val pipe = Pipe.open() + val source = new FileDownloadChannel(pipe.source()) + try { + val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) + val callback = new FileDownloadCallback(pipe.sink(), source, client) + client.stream(parsedUri.getPath(), callback) + } catch { + case e: Exception => + pipe.sink().close() + source.close() + throw e + } + + source + } + + private def downloadClient(host: String, port: Int): TransportClient = { + if (fileDownloadFactory == null) synchronized { + if (fileDownloadFactory == null) { + val module = "files" + val prefix = "spark.rpc.io." + val clone = conf.clone() + + // Copy any RPC configuration that is not overridden in the spark.files namespace. 
+ conf.getAll.foreach { case (key, value) => + if (key.startsWith(prefix)) { + val opt = key.substring(prefix.length()) + clone.setIfMissing(s"spark.$module.io.$opt", value) + } + } + + val ioThreads = clone.getInt("spark.files.io.threads", 1) + val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads) + val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) + fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) + } + } + fileDownloadFactory.createClient(host, port) + } + + private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { + + @volatile private var error: Throwable = _ + + def setError(e: Throwable): Unit = error = e + + override def read(dst: ByteBuffer): Int = { + if (error != null) { + throw error + } + source.read(dst) + } + + override def close(): Unit = source.close() + + override def isOpen(): Boolean = source.isOpen() + + } + + private class FileDownloadCallback( + sink: WritableByteChannel, + source: FileDownloadChannel, + client: TransportClient) extends StreamCallback { + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + while (buf.remaining() > 0) { + sink.write(buf) + } + } + + override def onComplete(streamId: String): Unit = { + sink.close() + } + + override def onFailure(streamId: String, cause: Throwable): Unit = { + logError(s"Error downloading stream $streamId.", cause) + source.setError(cause) + sink.close() + } + + } + } private[netty] object NettyRpcEnv extends Logging { @@ -420,7 +526,7 @@ private[netty] class NettyRpcEndpointRef( override def toString: String = s"NettyRpcEndpointRef(${_address})" - def toURI: URI = new URI(s"spark://${_address}") + def toURI: URI = new URI(_address.toString) final override def equals(that: Any): Boolean = that match { case other: NettyRpcEndpointRef => _address == other._address @@ -471,7 +577,9 @@ private[netty] case class RpcFailure(e: Throwable) * with different `RpcAddress` information). */ private[netty] class NettyRpcHandler( - dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { + dispatcher: Dispatcher, + nettyEnv: NettyRpcEnv, + streamManager: StreamManager) extends RpcHandler with Logging { // TODO: Can we add connection callback (channel registered) to the underlying framework? // A variable to track whether we should dispatch the RemoteProcessConnected message. 
@@ -498,7 +606,7 @@ private[netty] class NettyRpcHandler( dispatcher.postRemoteMessage(messageToDispatch, callback) } - override def getStreamManager: StreamManager = new OneForOneStreamManager + override def getStreamManager: StreamManager = streamManager override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] @@ -516,8 +624,8 @@ private[netty] class NettyRpcHandler( override def connectionTerminated(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) clients.remove(client) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) } else { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala new file mode 100644 index 000000000000..eb1d2604fb23 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc.netty + +import java.io.File +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.server.StreamManager +import org.apache.spark.rpc.RpcEnvFileServer + +/** + * StreamManager implementation for serving files from a NettyRpcEnv. 
+ */ +private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) + extends StreamManager with RpcEnvFileServer { + + private val files = new ConcurrentHashMap[String, File]() + private val jars = new ConcurrentHashMap[String, File]() + + override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { + throw new UnsupportedOperationException() + } + + override def openStream(streamId: String): ManagedBuffer = { + val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2) + val file = ftype match { + case "files" => files.get(fname) + case "jars" => jars.get(fname) + case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") + } + + require(file != null, s"File not found: $streamId") + new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + } + + override def addFile(file: File): String = { + require(files.putIfAbsent(file.getName(), file) == null, + s"File ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/files/${file.getName()}" + } + + override def addJar(file: File): String = { + require(jars.putIfAbsent(file.getName(), file) == null, + s"JAR ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1b3acb8ef7f5..af632349c9ca 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,6 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer +import java.nio.channels.Channels import java.util.concurrent._ import java.util.{Locale, Properties, Random, UUID} import javax.net.ssl.HttpsURLConnection @@ -535,6 +536,14 @@ private[spark] object Utils extends Logging { val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { + case "spark" => + if (SparkEnv.get == null) { + throw new IllegalStateException( + "Cannot retrieve files with 'spark' scheme without an active SparkEnv.") + } + val source = SparkEnv.get.rpcEnv.openChannel(url) + val is = Channels.newInputStream(source) + downloadFile(url, is, targetFile, fileOverwrite) case "http" | "https" | "ftp" => var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 2f55006420ce..2b664c6313ef 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.rpc -import java.io.NotSerializableException +import java.io.{File, NotSerializableException} +import java.util.UUID +import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} import scala.collection.mutable @@ -25,10 +27,14 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.io.Files +import org.mockito.Mockito.{mock, when} import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil 
+import org.apache.spark.util.Utils /** * Common tests for an RpcEnv implementation. @@ -40,12 +46,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val conf = new SparkConf() env = createRpcEnv(conf, "local", 0) + + val sparkEnv = mock(classOf[SparkEnv]) + when(sparkEnv.rpcEnv).thenReturn(env) + SparkEnv.set(sparkEnv) } override def afterAll(): Unit = { if (env != null) { env.shutdown() } + SparkEnv.set(null) } def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv @@ -713,6 +724,30 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) } + test("file server") { + val conf = new SparkConf() + val tempDir = Utils.createTempDir() + val file = new File(tempDir, "file") + Files.write(UUID.randomUUID().toString(), file, UTF_8) + val jar = new File(tempDir, "jar") + Files.write(UUID.randomUUID().toString(), jar, UTF_8) + + val fileUri = env.fileServer.addFile(file) + val jarUri = env.fileServer.addJar(jar) + + val destDir = Utils.createTempDir() + val destFile = new File(destDir, file.getName()) + val destJar = new File(destDir, jar.getName()) + + val sm = new SecurityManager(conf) + val hc = SparkHadoopUtil.get.conf + Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false) + Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false) + + assert(Files.equal(file, destFile)) + assert(Files.equal(jar, destJar)) + } + } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index f9d8e80c98b6..ccca795683da 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -25,17 +25,19 @@ import org.mockito.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} +import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc._ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())). 
- thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) + val sm = mock(classOf[StreamManager]) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) + .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) @@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { test("connectionTerminated") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) diff --git a/docs/configuration.md b/docs/configuration.md index c496146e3ed6..4de202d7f763 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1020,6 +1020,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the executor to listen on. This is used for communicating with the driver. + This is only relevant when using the Akka RPC backend. @@ -1027,6 +1028,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the driver's HTTP file server to listen on. + This is only relevant when using the Akka RPC backend. diff --git a/docs/security.md b/docs/security.md index 177109415180..e1af221d446b 100644 --- a/docs/security.md +++ b/docs/security.md @@ -149,7 +149,8 @@ configure those ports. (random) Schedule tasks spark.executor.port - Akka-based. Set to "0" to choose a port randomly. + Akka-based. Set to "0" to choose a port randomly. Only used if Akka RPC backend is + configured. Executor @@ -157,7 +158,7 @@ configure those ports. (random) File server for files and jars spark.fileserver.port - Jetty-based + Jetty-based. Only used if Akka RPC backend is configured. 
Executor diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 3ee6bd92e47f..55fe156cf665 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -148,7 +148,7 @@ List buildClassPath(String appClassPath) throws IOException { String scala = getScalaVersion(); List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "yarn", "launcher"); + "yarn", "launcher", "network/common", "network/shuffle", "network/yarn"); if (prependClasses) { if (!isTesting) { System.err.println( diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 659c47160c7b..61bafc838004 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -170,8 +170,10 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO } /** - * Create a completely new {@link TransportClient} to the given remote host / port - * But this connection is not pooled. + * Create a completely new {@link TransportClient} to the given remote host / port. + * This connection is not pooled. + * + * As with {@link #createClient(String, int)}, this method is blocking. */ public TransportClient createUnmanagedClient(String remoteHost, int remotePort) throws IOException { diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 29d688a67578..3164e0067903 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -138,6 +138,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } } } + ctx.fireUserEventTriggered(evt); } public TransportResponseHandler getResponseHandler() { From 9db5f601facfdaba6e4333a6b2d2e4a9f009c788 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 23 Nov 2015 16:33:26 -0800 Subject: [PATCH 861/896] [SPARK-9866][SQL] Speed up VersionsSuite by using persistent Ivy cache This patch attempts to speed up VersionsSuite by storing fetched Hive JARs in an Ivy cache that persists across tests runs. If `SPARK_VERSIONS_SUITE_IVY_PATH` is set, that path will be used for the cache; if it is not set, VersionsSuite will create a temporary Ivy cache which is deleted after the test completes. Author: Josh Rosen Closes #9624 from JoshRosen/SPARK-9866. 
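A minimal Scala sketch of the lookup described above (it condenses what the change below does and is only an illustration, not additional patch content):

```scala
import java.io.File

// Prefer an Ivy cache supplied through the environment; otherwise fall back to a stable
// directory under the JVM temp dir, mirroring the fallback in the diff below.
val ivyCachePath: String =
  sys.env.getOrElse(
    "SPARK_VERSIONS_SUITE_IVY_PATH",
    new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)
```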
--- .../apache/spark/sql/hive/client/VersionsSuite.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index c6d034a23a1c..7bc13bc60d30 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -36,10 +36,12 @@ import org.apache.spark.util.Utils @ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { - // Do not use a temp path here to speed up subsequent executions of the unit test during - // development. - private val ivyPath = Some( - new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + // In order to speed up test execution during development or in Jenkins, you can specify the path + // of an existing Ivy cache: + private val ivyPath: Option[String] = { + sys.env.get("SPARK_VERSIONS_SUITE_IVY_PATH").orElse( + Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) + } private def buildConf() = { lazy val warehousePath = Utils.createTempDir() From 105745645b12afbbc2a350518cb5853a88944183 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 23 Nov 2015 17:11:51 -0800 Subject: [PATCH 862/896] [SPARK-10560][PYSPARK][MLLIB][DOCS] Make StreamingLogisticRegressionWithSGD Python API equal to Scala one This is to bring the API documentation of StreamingLogisticReressionWithSGD and StreamingLinearRegressionWithSGC in line with the Scala versions. -Fixed the algorithm descriptions -Added default values to parameter descriptions -Changed StreamingLogisticRegressionWithSGD regParam to default to 0, as in the Scala version Author: Bryan Cutler Closes #9141 from BryanCutler/StreamingLogisticRegressionWithSGD-python-api-sync. --- python/pyspark/mllib/classification.py | 37 +++++++++++++++++--------- python/pyspark/mllib/regression.py | 32 ++++++++++++++-------- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index aab4015ba80f..9e6f17ef6e94 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -652,21 +652,34 @@ def train(cls, data, lambda_=1.0): @inherit_doc class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LogisticRegression with SGD on a batch of data. - - The weights obtained at the end of training a stream are used as initial - weights for the next batch. - - :param stepSize: Step size for each iteration of gradient descent. - :param numIterations: Number of iterations run for each batch of data. - :param miniBatchFraction: Fraction of data on which SGD is run for each - iteration. - :param regParam: L2 Regularization parameter. - :param convergenceTol: A condition which decides iteration termination. + Train or predict a logistic regression model on streaming data. Training uses + Stochastic Gradient Descent to update the model based on each new batch of + incoming data from a DStream. + + Each batch of data is assumed to be an RDD of LabeledPoints. + The number of data points per batch can vary, but the number + of features must be constant. An initial weight + vector must be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. 
+ (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param regParam: + L2 Regularization parameter. + (default: 0.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) .. versionadded:: 1.5.0 """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01, + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.0, convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 6f00d1df209c..13b3397501c0 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -734,17 +734,27 @@ def predictOnValues(self, dstream): @inherit_doc class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LinearRegression with SGD on a batch of data. - - The problem minimized is (1 / n_samples) * (y - weights'X)**2. - After training on a batch of data, the weights obtained at the end of - training are used as initial weights for the next batch. - - :param stepSize: Step size for each iteration of gradient descent. - :param numIterations: Total number of iterations run. - :param miniBatchFraction: Fraction of data on which SGD is run for each - iteration. - :param convergenceTol: A condition which decides iteration termination. + Train or predict a linear regression model on streaming data. Training uses + Stochastic Gradient Descent to update the model based on each new batch of + incoming data from a DStream (see `LinearRegressionWithSGD` for model equation). + + Each batch of data is assumed to be an RDD of LabeledPoints. + The number of data points per batch can vary, but the number + of features must be constant. An initial weight + vector must be provided. + + :param stepSize: + Step size for each iteration of gradient descent. + (default: 0.1) + :param numIterations: + Number of iterations run for each batch of data. + (default: 50) + :param miniBatchFraction: + Fraction of each batch of data to use for updates. + (default: 1.0) + :param convergenceTol: + Value used to determine when to terminate iterations. + (default: 0.001) .. versionadded:: 1.5.0 """ From 026ea2eab1f3cde270e8a6391d002915f3e1c6e5 Mon Sep 17 00:00:00 2001 From: Stephen Samuel Date: Mon, 23 Nov 2015 19:52:12 -0800 Subject: [PATCH 863/896] Updated sql programming guide to include jdbc fetch size Author: Stephen Samuel Closes #9377 from sksamuel/master. --- docs/sql-programming-guide.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e347754055e7..d7b205c2fa0d 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1820,6 +1820,7 @@ the Data Sources API. The following options are supported: register itself with the JDBC subsystem. + partitionColumn, lowerBound, upperBound, numPartitions @@ -1831,6 +1832,13 @@ the Data Sources API. The following options are supported: partitioned and returned. + + + fetchSize + + The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (eg. Oracle with 10 rows). + +
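A hedged usage sketch for the `fetchSize` option documented above (the connection URL, table name, and chosen value are placeholders, not part of the patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setAppName("jdbc-fetchsize-example").setMaster("local[*]"))
val sqlContext = new SQLContext(sc)

// Ask the JDBC driver to pull 1000 rows per round trip instead of its (often small) default.
val jdbcDF = sqlContext.read
  .format("jdbc")
  .option("url", "jdbc:oracle:thin:@//dbhost:1521/service")  // placeholder URL
  .option("dbtable", "MY_SCHEMA.MY_TABLE")                   // placeholder table
  .option("fetchSize", "1000")
  .load()
```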
    From 8d57524662fad4a0760f3bc924e690c2a110e7f7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 23 Nov 2015 22:22:15 -0800 Subject: [PATCH 864/896] [SPARK-11933][SQL] Rename mapGroup -> mapGroups and flatMapGroup -> flatMapGroups. Based on feedback from Matei, this is more consistent with mapPartitions in Spark. Also addresses some of the cleanups from a previous commit that renames the type variables. Author: Reynold Xin Closes #9919 from rxin/SPARK-11933. --- ...nction.java => FlatMapGroupsFunction.java} | 2 +- ...upFunction.java => MapGroupsFunction.java} | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 36 +++++++++---------- .../apache/spark/sql/JavaDatasetSuite.java | 10 +++--- .../spark/sql/DatasetPrimitiveSuite.scala | 4 +-- .../org/apache/spark/sql/DatasetSuite.scala | 12 +++---- 6 files changed, 33 insertions(+), 33 deletions(-) rename core/src/main/java/org/apache/spark/api/java/function/{FlatMapGroupFunction.java => FlatMapGroupsFunction.java} (93%) rename core/src/main/java/org/apache/spark/api/java/function/{MapGroupFunction.java => MapGroupsFunction.java} (93%) diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java similarity index 93% rename from core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java rename to core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java index 18a2d733ca70..d7a80e7b129b 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java @@ -23,6 +23,6 @@ /** * A function that returns zero or more output records from each grouping key and its values. */ -public interface FlatMapGroupFunction extends Serializable { +public interface FlatMapGroupsFunction extends Serializable { Iterable call(K key, Iterator values) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java similarity index 93% rename from core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java rename to core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java index 4f3f222e064b..faa59eabc8b4 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java @@ -23,6 +23,6 @@ /** * Base interface for a map function used in GroupedDataset's mapGroup function. 
*/ -public interface MapGroupFunction extends Serializable { +public interface MapGroupsFunction extends Serializable { R call(K key, Iterator values) throws Exception; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 7f43ce16901b..793a86b13290 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.expressions.Aggregator @Experimental class GroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], - tEncoder: Encoder[V], + vEncoder: Encoder[V], val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { @@ -53,12 +53,12 @@ class GroupedDataset[K, V] private[sql]( // queryexecution. private implicit val unresolvedKEncoder = encoderFor(kEncoder) - private implicit val unresolvedTEncoder = encoderFor(tEncoder) + private implicit val unresolvedVEncoder = encoderFor(vEncoder) private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) - private val resolvedTEncoder = - unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes) + private val resolvedVEncoder = + unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext @@ -76,7 +76,7 @@ class GroupedDataset[K, V] private[sql]( def keyAs[L : Encoder]: GroupedDataset[L, V] = new GroupedDataset( encoderFor[L], - unresolvedTEncoder, + unresolvedVEncoder, queryExecution, dataAttributes, groupingAttributes) @@ -110,13 +110,13 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { + def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups( f, resolvedKEncoder, - resolvedTEncoder, + resolvedVEncoder, groupingAttributes, logicalPlan)) } @@ -138,8 +138,8 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder) + def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) } /** @@ -158,9 +158,9 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { + def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { val func = (key: K, it: Iterator[V]) => Iterator(f(key, it)) - flatMapGroup(func) + flatMapGroups(func) } /** @@ -179,8 +179,8 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - mapGroup((key, data) => f.call(key, data.asJava))(encoder) + def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { + mapGroups((key, data) => f.call(key, data.asJava))(encoder) } /** @@ -192,8 +192,8 @@ class GroupedDataset[K, V] private[sql]( def reduce(f: (V, V) => V): Dataset[(K, V)] = { val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) - implicit val resultEncoder = 
ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder) - flatMapGroup(func) + implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) + flatMapGroups(func) } /** @@ -213,7 +213,7 @@ class GroupedDataset[K, V] private[sql]( private def withEncoder(c: Column): Column = c match { case tc: TypedColumn[_, _] => - tc.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes) + tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes) case _ => c } @@ -227,7 +227,7 @@ class GroupedDataset[K, V] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map( - _.withInputType(resolvedTEncoder, dataAttributes).named) + _.withInputType(resolvedVEncoder, dataAttributes).named) val keyColumn = if (groupingAttributes.length > 1) { Alias(CreateStruct(groupingAttributes), "key")() } else { @@ -304,7 +304,7 @@ class GroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.unresolvedTEncoder + implicit def uEnc: Encoder[U] = other.unresolvedVEncoder new Dataset[R]( sqlContext, CoGroup( diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index cf335efdd23b..67a3190cb7d4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -170,7 +170,7 @@ public Integer call(String v) throws Exception { } }, Encoders.INT()); - Dataset mapped = grouped.mapGroup(new MapGroupFunction() { + Dataset mapped = grouped.mapGroups(new MapGroupsFunction() { @Override public String call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -183,8 +183,8 @@ public String call(Integer key, Iterator values) throws Exception { Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); - Dataset flatMapped = grouped.flatMapGroup( - new FlatMapGroupFunction() { + Dataset flatMapped = grouped.flatMapGroups( + new FlatMapGroupsFunction() { @Override public Iterable call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); @@ -249,8 +249,8 @@ public void testGroupByColumn() { GroupedDataset grouped = ds.groupBy(length(col("value"))).keyAs(Encoders.INT()); - Dataset mapped = grouped.mapGroup( - new MapGroupFunction() { + Dataset mapped = grouped.mapGroups( + new MapGroupsFunction() { @Override public String call(Integer key, Iterator data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index d387710357be..f75d0961823c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.mapGroup { case (g, iter) => + val agged = grouped.mapGroups { case (g, iter) => val name = if (g == 0) "even" else "odd" (name, iter.size) } @@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends 
QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq("a", "b", "c", "xyz", "hello").toDS() val grouped = ds.groupBy(_.length) - val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) } + val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) } checkAnswer( agged, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index cc8e4325fd2f..dbdd7ba14a5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -224,7 +224,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.mapGroup { case (g, iter) => (g._1, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, @@ -234,7 +234,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.flatMapGroup { case (g, iter) => + val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } @@ -255,7 +255,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.mapGroup { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, @@ -265,7 +265,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").keyAs[String] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -275,7 +275,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -285,7 +285,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] - val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) } + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, From 6cf51a7007bd72eb93ade149ca9fc53be5b32a17 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Mon, 23 Nov 2015 22:22:50 -0800 Subject: [PATCH 865/896] [SPARK-11903] Remove --skip-java-test Per [pwendell's comments on 
SPARK-11903](https://issues.apache.org/jira/browse/SPARK-11903?focusedCommentId=15021511&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15021511) I'm removing this dead code. If we are concerned about preserving compatibility, I can instead leave the option in and add a warning. For example: ```sh echo "Warning: '--skip-java-test' is deprecated and has no effect." ;; ``` cc pwendell, srowen Author: Nicholas Chammas Closes #9924 from nchammas/make-distribution. --- make-distribution.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/make-distribution.sh b/make-distribution.sh index d7d27e253f72..7b417fe7cf61 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -69,9 +69,6 @@ while (( "$#" )); do echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" exit_with_usage ;; - --skip-java-test) - SKIP_JAVA_TEST=true - ;; --with-tachyon) SPARK_TACHYON=true ;; From 4021a28ac30b65cb61cf1e041253847253a2d89f Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Mon, 23 Nov 2015 22:26:08 -0800 Subject: [PATCH 866/896] [SPARK-10707][SQL] Fix nullability computation in union output Author: Mikhail Bautin Closes #9308 from mbautin/SPARK-10707. --- .../plans/logical/basicOperators.scala | 11 +++++-- .../spark/sql/execution/basicOperators.scala | 9 ++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 31 +++++++++++++++++++ 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 0c444482c5e4..737e62fd5921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -92,8 +92,10 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - // TODO: These aren't really the same attributes as nullability etc might change. - final override def output: Seq[Attribute] = left.output + override def output: Seq[Attribute] = + left.output.zip(right.output).map { case (leftAttr, rightAttr) => + leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable) + } final override lazy val resolved: Boolean = childrenResolved && @@ -115,7 +117,10 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(lef case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) -case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) +case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + /** We don't use right.output because those rows get excluded from the set. */ + override def output: Seq[Attribute] = left.output +} case class Join( left: LogicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e79092efdaa3..d57b8e7a9ed6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -130,8 +130,13 @@ case class Sample( * Union two plans, without a distinct. This is UNION ALL in SQL. 
*/ case class Union(children: Seq[SparkPlan]) extends SparkPlan { - // TODO: attributes output by union should be distinct for nullability purposes - override def output: Seq[Attribute] = children.head.output + override def output: Seq[Attribute] = { + children.tail.foldLeft(children.head.output) { case (currentOutput, child) => + currentOutput.zip(child.output).map { case (a1, a2) => + a1.withNullability(a1.nullable || a2.nullable) + } + } + } override def outputsUnsafeRows: Boolean = children.forall(_.outputsUnsafeRows) override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 167aea87de07..bb82b562aaaa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1997,4 +1997,35 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } + + test("SPARK-10707: nullability should be correctly propagated through set operations (1)") { + // This test produced an incorrect result of 1 before the SPARK-10707 fix because of the + // NullPropagation rule: COUNT(v) got replaced with COUNT(1) because the output column of + // UNION was incorrectly considered non-nullable: + checkAnswer( + sql("""SELECT count(v) FROM ( + | SELECT v FROM ( + | SELECT 'foo' AS v UNION ALL + | SELECT NULL AS v + | ) my_union WHERE isnull(v) + |) my_subview""".stripMargin), + Seq(Row(0))) + } + + test("SPARK-10707: nullability should be correctly propagated through set operations (2)") { + // This test uses RAND() to stop column pruning for Union and checks the resulting isnull + // value. This would produce an incorrect result before the fix in SPARK-10707 because the "v" + // column of the union was considered non-nullable. + checkAnswer( + sql( + """ + |SELECT a FROM ( + | SELECT ISNULL(v) AS a, RAND() FROM ( + | SELECT 'foo' AS v UNION ALL SELECT null AS v + | ) my_union + |) my_view + """.stripMargin), + Row(false) :: Row(true) :: Nil) + } + } From 12eea834d7382fbaa9c92182b682b8724049d7c1 Mon Sep 17 00:00:00 2001 From: Xiu Guo Date: Tue, 24 Nov 2015 00:07:40 -0800 Subject: [PATCH 867/896] [SPARK-11897][SQL] Add @scala.annotations.varargs to sql functions Author: Xiu Guo Closes #9918 from xguo27/SPARK-11897. 
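A small illustrative sketch of what the annotation buys Java callers (the object and method below are stand-ins, not the patched `functions` object):

```scala
import scala.annotation.varargs

object VarargsExample {
  // Without @varargs the compiler emits only a Seq-taking method, which Java code cannot
  // call with a plain argument list; with it, a Java-style String... overload is generated too.
  @varargs
  def struct(colName: String, colNames: String*): Seq[String] = colName +: colNames
}

// From Java this can now be called as: VarargsExample.struct("a", "b", "c");
```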
--- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b27b1340cce4..6137ce3a70fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -689,6 +689,7 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def array(colName: String, colNames: String*): Column = { array((colName +: colNames).map(col) : _*) } @@ -871,6 +872,7 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ + @scala.annotation.varargs def struct(colName: String, colNames: String*): Column = { struct((colName +: colNames).map(col) : _*) } From 800bd799acf7f10a469d8d6537279953129eb2c6 Mon Sep 17 00:00:00 2001 From: Forest Fang Date: Tue, 24 Nov 2015 09:03:32 +0000 Subject: [PATCH 868/896] [SPARK-11906][WEB UI] Speculation Tasks Cause ProgressBar UI Overflow When there are speculative tasks in the stage, running progress bar could overflow and goes hidden on a new line: ![image](https://cloud.githubusercontent.com/assets/4317392/11326841/5fd3482e-9142-11e5-8ca5-cb2f0c0c8964.png) 3 completed / 2 running (including 1 speculative) out of 4 total tasks This is a simple fix by capping the started tasks at `total - completed` tasks ![image](https://cloud.githubusercontent.com/assets/4317392/11326842/6bb67260-9142-11e5-90f0-37f9174878ec.png) I should note my preferred way to fix it is via css style ```css .progress { display: flex; } ``` which shifts the correction burden from driver to web browser. However I couldn't get selenium test to measure the position/dimension of the progress bar correctly to get this unit tested. It also has the side effect that the width will be calibrated so the running occupies 2 / 5 instead of 1 / 4. ![image](https://cloud.githubusercontent.com/assets/4317392/11326848/7b03e9f0-9142-11e5-89ad-bd99cb0647cf.png) All in all, since this cosmetic bug is minor enough, I suppose the original simple fix should be good enough. Author: Forest Fang Closes #9896 from saurfang/progressbar. --- core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 4 +++- .../test/scala/org/apache/spark/ui/UIUtilsSuite.scala | 10 ++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 25dcb604d9e5..84a1116a5c49 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -319,7 +319,9 @@ private[spark] object UIUtils extends Logging { skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) - val startWidth = "width: %s%%".format((started.toDouble/total)*100) + // started + completed can be > total when there are speculative tasks + val boundedStarted = math.min(started, total - completed) + val startWidth = "width: %s%%".format((boundedStarted.toDouble/total)*100)
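A worked example of the capping above, using the same numbers as the test added below (illustration only, not additional patch content):

```scala
// started = 2, completed = 3, total = 4: without the cap the two bars would sum to 125%.
val (started, completed, total) = (2, 3, 4)
val completeWidth = "width: %s%%".format((completed.toDouble / total) * 100)    // "width: 75.0%"
val boundedStarted = math.min(started, total - completed)                       // 1, not 2
val startWidth = "width: %s%%".format((boundedStarted.toDouble / total) * 100)  // "width: 25.0%"
```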
    diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 2b693c165180..dd8d5ec27f87 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -57,6 +57,16 @@ class UIUtilsSuite extends SparkFunSuite { ) } + test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { + val generated = makeProgressBar(2, 3, 0, 0, 4).head.child.filter(_.label == "div") + val expected = Seq( +
    , +
    + ) + assert(generated.sameElements(expected), + s"\nRunning progress bar should round down\n\nExpected:\n$expected\nGenerated:\n$generated") + } + private def verify( desc: String, expected: Elem, errorMsg: String = "", baseUrl: String = ""): Unit = { val generated = makeDescription(desc, baseUrl) From d4a5e6f719079639ffd38470f4d8d1f6fde3228d Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Tue, 24 Nov 2015 23:24:49 +0800 Subject: [PATCH 869/896] [SPARK-11043][SQL] BugFix:Set the operator log in the thrift server. `SessionManager` will set the `operationLog` if the configuration `hive.server2.logging.operation.enabled` is true in version of hive 1.2.1. But the spark did not adapt to this change, so no matter enabled the configuration or not, spark thrift server will always log the warn message. PS: if `hive.server2.logging.operation.enabled` is false, it should log the warn message (the same as hive thrift server). Author: huangzhaowei Closes #9056 from SaintBacchus/SPARK-11043. --- .../SparkExecuteStatementOperation.scala | 8 ++++---- .../thriftserver/SparkSQLSessionManager.scala | 5 +++++ .../thriftserver/HiveThriftServer2Suites.scala | 16 +++++++++++++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 82fef92dcb73..e022ee86a763 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -134,12 +134,12 @@ private[hive] class SparkExecuteStatementOperation( def getResultSetSchema: TableSchema = resultSchema - override def run(): Unit = { + override def runInternal(): Unit = { setState(OperationState.PENDING) setHasResultSet(true) // avoid no resultset for async run if (!runInBackground) { - runInternal() + execute() } else { val sparkServiceUGI = Utils.getUGI() @@ -151,7 +151,7 @@ private[hive] class SparkExecuteStatementOperation( val doAsAction = new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { try { - runInternal() + execute() } catch { case e: HiveSQLException => setOperationException(e) @@ -188,7 +188,7 @@ private[hive] class SparkExecuteStatementOperation( } } - override def runInternal(): Unit = { + private def execute(): Unit = { statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index af4fcdf021bd..de4e9c62b57a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -41,6 +41,11 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) + // Create operation log root directory, if operation logging is enabled + if (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) { + invoke(classOf[SessionManager], this, 
"initOperationLogRootDir") + } + val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) getAncestorField[Log](this, 3, "LOG").info( diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 1dd898aa3835..139d8e897ba1 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.{Await, Promise, future} +import scala.io.Source import scala.util.{Random, Try} import com.google.common.base.Charsets.UTF_8 @@ -507,6 +508,12 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { assert(rs2.getInt(2) === 500) } } + + test("SPARK-11043 check operation log root directory") { + val expectedLine = + "Operation log root directory is created: " + operationLogPath.getAbsoluteFile + assert(Source.fromFile(logPath).getLines().exists(_.contains(expectedLine))) + } } class SingleSessionSuite extends HiveThriftJdbcTest { @@ -642,7 +649,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl protected def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" private val pidDir: File = Utils.createTempDir("thriftserver-pid") - private var logPath: File = _ + protected var logPath: File = _ + protected var operationLogPath: File = _ private var logTailingProcess: Process = _ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] @@ -679,6 +687,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf ${ConfVars.HIVE_SERVER2_LOGGING_OPERATION_LOG_LOCATION}=$operationLogPath | --hiveconf $portConf=$port | --driver-class-path $driverClassPath | --driver-java-options -Dlog4j.debug @@ -706,6 +715,8 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl warehousePath.delete() metastorePath = Utils.createTempDir() metastorePath.delete() + operationLogPath = Utils.createTempDir() + operationLogPath.delete() logPath = null logTailingProcess = null @@ -782,6 +793,9 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl metastorePath.delete() metastorePath = null + operationLogPath.delete() + operationLogPath = null + Option(logPath).foreach(_.delete()) logPath = null From 5889880fbe9628681042036892ef7ebd4f0857b4 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 24 Nov 2015 23:32:05 +0800 Subject: [PATCH 870/896] [SPARK-11592][SQL] flush spark-sql command line history to history file Currently, `spark-sql` would not flush command history when exiting. Author: Daoyuan Wang Closes #9563 from adrian-wang/jline. 
--- .../hive/thriftserver/SparkSQLCLIDriver.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 6419002a2aa8..4b928e600b35 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -194,6 +194,22 @@ private[hive] object SparkSQLCLIDriver extends Logging { logWarning(e.getMessage) } + // add shutdown hook to flush the history to history file + Runtime.getRuntime.addShutdownHook(new Thread(new Runnable() { + override def run() = { + reader.getHistory match { + case h: FileHistory => + try { + h.flush() + } catch { + case e: IOException => + logWarning("WARNING: Failed to write command history file: " + e.getMessage) + } + case _ => + } + } + })) + // TODO: missing /* val clientTransportTSocketField = classOf[CliSessionState].getDeclaredField("transport") From be9dd1550c1816559d3d418a19c692e715f1c94e Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 24 Nov 2015 09:20:09 -0800 Subject: [PATCH 871/896] =?UTF-8?q?[SPARK-11818][REPL]=20Fix=20ExecutorCla?= =?UTF-8?q?ssLoader=20to=20lookup=20resources=20from=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …parent class loader Without patch, two additional tests of ExecutorClassLoaderSuite fails. - "resource from parent" - "resources from parent" Detailed explanation is here, https://issues.apache.org/jira/browse/SPARK-11818?focusedCommentId=15011202&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15011202 Author: Jungtaek Lim Closes #9812 from HeartSaVioR/SPARK-11818. --- .../spark/repl/ExecutorClassLoader.scala | 12 +++++++- .../spark/repl/ExecutorClassLoaderSuite.scala | 29 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index a976e96809cb..a8859fcd4584 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -34,7 +34,9 @@ import org.apache.spark.util.ParentClassLoader /** * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, * used to load classes defined by the interpreter when the REPL is used. - * Allows the user to specify if user class path should be first + * Allows the user to specify if user class path should be first. + * This class loader delegates getting/finding resources to parent loader, + * which makes sense until REPL never provide resource dynamically. 
*/ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader, userClassPathFirst: Boolean) extends ClassLoader with Logging { @@ -55,6 +57,14 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } } + override def getResource(name: String): URL = { + parentLoader.getResource(name) + } + + override def getResources(name: String): java.util.Enumeration[URL] = { + parentLoader.getResources(name) + } + override def findClass(name: String): Class[_] = { userClassPathFirst match { case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name)) diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index a58eda12b112..c1211f7596b9 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -19,8 +19,13 @@ package org.apache.spark.repl import java.io.File import java.net.{URL, URLClassLoader} +import java.nio.charset.StandardCharsets +import java.util + +import com.google.common.io.Files import scala.concurrent.duration._ +import scala.io.Source import scala.language.implicitConversions import scala.language.postfixOps @@ -41,6 +46,7 @@ class ExecutorClassLoaderSuite val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") + val parentResourceNames = List("fake-resource.txt") var tempDir1: File = _ var tempDir2: File = _ var url1: String = _ @@ -54,6 +60,9 @@ class ExecutorClassLoaderSuite url1 = "file://" + tempDir1 urls2 = List(tempDir2.toURI.toURL).toArray childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) + parentResourceNames.foreach { x => + Files.write("resource".getBytes(StandardCharsets.UTF_8), new File(tempDir2, x)) + } parentClassNames.foreach(TestUtils.createCompiledClass(_, tempDir2, "2")) } @@ -99,6 +108,26 @@ class ExecutorClassLoaderSuite } } + test("resource from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val is = classLoader.getResourceAsStream(resourceName) + assert(is != null, s"Resource $resourceName not found") + val content = Source.fromInputStream(is, "UTF-8").getLines().next() + assert(content.contains("resource"), "File doesn't contain 'resource'") + } + + test("resources from parent") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true) + val resourceName: String = parentResourceNames.head + val resources: util.Enumeration[URL] = classLoader.getResources(resourceName) + assert(resources.hasMoreElements, s"Resource $resourceName not found") + val fileReader = Source.fromInputStream(resources.nextElement().openStream()).bufferedReader() + assert(fileReader.readLine().contains("resource"), "File doesn't contain 'resource'") + } + test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") { // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class // from the driver's class server would leak a HTTP connection, causing the class server's From e5aaae6e1145b8c25c4872b2992ab425da9c6f9b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2015 09:28:39 -0800 Subject: [PATCH 
872/896] [SPARK-11942][SQL] fix encoder life cycle for CoGroup we should pass in resolved encodera to logical `CoGroup` and bind them in physical `CoGroup` Author: Wenchen Fan Closes #9928 from cloud-fan/cogroup. --- .../plans/logical/basicOperators.scala | 27 ++++++++++--------- .../org/apache/spark/sql/GroupedDataset.scala | 4 ++- .../spark/sql/execution/basicOperators.scala | 20 +++++++------- .../org/apache/spark/sql/DatasetSuite.scala | 12 +++++++++ 4 files changed, 41 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 737e62fd5921..5665fd7e5f41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -553,19 +553,22 @@ case class MapGroups[K, T, U]( /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { - def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], + def apply[Key, Left, Right, Result : Encoder]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], + leftEnc: ExpressionEncoder[Left], + rightEnc: ExpressionEncoder[Right], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], left: LogicalPlan, - right: LogicalPlan): CoGroup[K, Left, Right, R] = { + right: LogicalPlan): CoGroup[Key, Left, Right, Result] = { CoGroup( func, - encoderFor[K], - encoderFor[Left], - encoderFor[Right], - encoderFor[R], - encoderFor[R].schema.toAttributes, + keyEnc, + leftEnc, + rightEnc, + encoderFor[Result], + encoderFor[Result].schema.toAttributes, leftGroup, rightGroup, left, @@ -577,12 +580,12 @@ object CoGroup { * A relation produced by applying `func` to each grouping key and associated values from left and * right children. 
*/ -case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], - kEncoder: ExpressionEncoder[K], +case class CoGroup[Key, Left, Right, Result]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], + resultEnc: ExpressionEncoder[Result], output: Seq[Attribute], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 793a86b13290..a10a89342fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -304,11 +304,13 @@ class GroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.unresolvedVEncoder new Dataset[R]( sqlContext, CoGroup( f, + this.resolvedKEncoder, + this.resolvedVEncoder, + other.resolvedVEncoder, this.groupingAttributes, other.groupingAttributes, this.logicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d57b8e7a9ed6..a42aea0b96d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -375,12 +375,12 @@ case class MapGroups[K, T, U]( * iterators containing all elements in the group from left and right side. * The result of this function is encoded and flattened before being output. 
*/ -case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], - kEncoder: ExpressionEncoder[K], +case class CoGroup[Key, Left, Right, Result]( + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + keyEnc: ExpressionEncoder[Key], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], - rEncoder: ExpressionEncoder[R], + resultEnc: ExpressionEncoder[Result], output: Seq[Attribute], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], @@ -397,15 +397,17 @@ case class CoGroup[K, Left, Right, R]( left.execute().zipPartitions(right.execute()) { (leftData, rightData) => val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val groupKeyEncoder = kEncoder.bind(leftGroup) + val boundKeyEnc = keyEnc.bind(leftGroup) + val boundLeftEnc = leftEnc.bind(left.output) + val boundRightEnc = rightEnc.bind(right.output) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { case (key, leftResult, rightResult) => val result = func( - groupKeyEncoder.fromRow(key), - leftResult.map(leftEnc.fromRow), - rightResult.map(rightEnc.fromRow)) - result.map(rEncoder.toRow) + boundKeyEnc.fromRow(key), + leftResult.map(boundLeftEnc.fromRow), + rightResult.map(boundRightEnc.fromRow)) + result.map(resultEnc.toRow) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index dbdd7ba14a5b..13eede1b17d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -340,6 +340,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er") } + test("cogroup with complex data") { + val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS() + val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS() + val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) => + Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString)) + } + + checkAnswer( + cogrouped, + 1 -> "a", 2 -> "bc", 3 -> "d") + } + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") From 56a0aba0a60326ba026056c9a23f3f6ec7258c19 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 24 Nov 2015 09:52:53 -0800 Subject: [PATCH 873/896] [SPARK-11952][ML] Remove duplicate ml examples Remove duplicate ml examples (only for ml). mengxr Author: Yanbo Liang Closes #9933 from yanboliang/SPARK-11685. 
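For context on the encoder fix in the CoGroup patch above, here is a minimal usage sketch of the user-facing call it targets, modeled directly on the new `DatasetSuite` test; the `ClassData` case class and the `sqlContext` implicits are assumed to be in scope, as they are in the test fixture.

```scala
// Sketch only: mirrors the "cogroup with complex data" test added above.
import sqlContext.implicits._
case class ClassData(a: String, b: Int)

val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS()
val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS()

// The resolved key/value encoders are now carried into the logical CoGroup and
// bound to each child's output in the physical operator, so complex values such
// as ClassData deserialize correctly on both sides of the cogroup.
val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { (key, left, right) =>
  Iterator(key -> (left.map(_._2.a).mkString + right.map(_._2.a).mkString))
}

cogrouped.collect()  // expected: (1, "a"), (2, "bc"), (3, "d")
```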
--- .../main/python/ml/gradient_boosted_trees.py | 82 ----------------- .../src/main/python/ml/logistic_regression.py | 66 -------------- .../main/python/ml/random_forest_example.py | 87 ------------------- 3 files changed, 235 deletions(-) delete mode 100644 examples/src/main/python/ml/gradient_boosted_trees.py delete mode 100644 examples/src/main/python/ml/logistic_regression.py delete mode 100644 examples/src/main/python/ml/random_forest_example.py diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py deleted file mode 100644 index c3bf8aa2eb1e..000000000000 --- a/examples/src/main/python/ml/gradient_boosted_trees.py +++ /dev/null @@ -1,82 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import GBTClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import GBTRegressor -from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. -Note: GBTClassifier only supports binary classification currently -Run with: - bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py -""" - - -def testClassification(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = BinaryClassificationMetrics(predictionAndLabels) - print("AUC %.3f" % metrics.areaUnderROC) - - -def testRegression(train, test): - # Train a GradientBoostedTrees model. - - rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: gradient_boosted_trees", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonGBTExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. 
- df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py deleted file mode 100644 index 4cd027fdfbe8..000000000000 --- a/examples/src/main/python/ml/logistic_regression.py +++ /dev/null @@ -1,66 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import LogisticRegression -from pyspark.mllib.evaluation import MulticlassMetrics -from pyspark.ml.feature import StringIndexer -from pyspark.sql import SQLContext - -""" -A simple example demonstrating a logistic regression with elastic net regularization Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/logistic_regression.py -""" - -if __name__ == "__main__": - - if len(sys.argv) > 1: - print("Usage: logistic_regression", file=sys.stderr) - exit(-1) - - sc = SparkContext(appName="PythonLogisticRegressionExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. - df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [training, test] = td.randomSplit([0.7, 0.3]) - - lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") - lr.setElasticNetParam(0.8) - - # Fit the model - lrModel = lr.fit(training) - - predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py deleted file mode 100644 index dc6a77867019..000000000000 --- a/examples/src/main/python/ml/random_forest_example.py +++ /dev/null @@ -1,87 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from __future__ import print_function - -import sys - -from pyspark import SparkContext -from pyspark.ml.classification import RandomForestClassifier -from pyspark.ml.feature import StringIndexer -from pyspark.ml.regression import RandomForestRegressor -from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics -from pyspark.mllib.util import MLUtils -from pyspark.sql import Row, SQLContext - -""" -A simple example demonstrating a RandomForest Classification/Regression Pipeline. -Run with: - bin/spark-submit examples/src/main/python/ml/random_forest_example.py -""" - - -def testClassification(train, test): - # Train a RandomForest model. - # Setting featureSubsetStrategy="auto" lets the algorithm choose. - # Note: Use larger numTrees in practice. - - rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = MulticlassMetrics(predictionAndLabels) - print("weighted f-measure %.3f" % metrics.weightedFMeasure()) - print("precision %s" % metrics.precision()) - print("recall %s" % metrics.recall()) - - -def testRegression(train, test): - # Train a RandomForest model. - # Note: Use larger numTrees in practice. - - rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) - - model = rf.fit(train) - predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ - .map(lambda x: (x.prediction, x.indexedLabel)) - - metrics = RegressionMetrics(predictionAndLabels) - print("rmse %.3f" % metrics.rootMeanSquaredError) - print("r2 %.3f" % metrics.r2) - print("mae %.3f" % metrics.meanAbsoluteError) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: random_forest_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonRandomForestExample") - sqlContext = SQLContext(sc) - - # Load the data stored in LIBSVM format as a DataFrame. - df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - - # Map labels into an indexed column of labels in [0, numLabels) - stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") - si_model = stringIndexer.fit(df) - td = si_model.transform(df) - [train, test] = td.randomSplit([0.7, 0.3]) - testClassification(train, test) - testRegression(train, test) - sc.stop() From 9e24ba667e43290fbaa3cacb93cf5d9be790f1fd Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 24 Nov 2015 09:54:55 -0800 Subject: [PATCH 874/896] [SPARK-11521][ML][DOC] Document that Logistic, Linear Regression summaries ignore weight col Doc for 1.6 that the summaries mostly ignore the weight column. To be corrected for 1.7 CC: mengxr thunterdb Author: Joseph K. 
Bradley Closes #9927 from jkbradley/linregsummary-doc. --- .../ml/classification/LogisticRegression.scala | 18 ++++++++++++++++++ .../spark/ml/regression/LinearRegression.scala | 15 +++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 418bbdc9a058..d320d64dd90d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -755,23 +755,35 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns the receiver operating characteristic (ROC) curve, * which is an Dataframe having two fields (FPR, TPR) * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") /** * Computes the area under the receiver operating characteristic (ROC) curve. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() /** * Returns the precision-recall curve, which is an Dataframe containing * two fields recall, precision with (0.0, 1.0) prepended to it. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") /** * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val fMeasureByThreshold: DataFrame = { binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure") @@ -781,6 +793,9 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns a dataframe with two fields (threshold, precision) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the precision. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. */ @transient lazy val precisionByThreshold: DataFrame = { binaryMetrics.precisionByThreshold().toDF("threshold", "precision") @@ -790,6 +805,9 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns a dataframe with two fields (threshold, recall) curve. * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the recall. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * This will change in later Spark versions. 
*/ @transient lazy val recallByThreshold: DataFrame = { binaryMetrics.recallByThreshold().toDF("threshold", "recall") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 70ccec766c47..1db91666f21a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -540,6 +540,9 @@ class LinearRegressionSummary private[regression] ( * Returns the explained variance regression score. * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance @@ -547,6 +550,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError @@ -554,6 +560,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError @@ -561,6 +570,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError @@ -568,6 +580,9 @@ class LinearRegressionSummary private[regression] ( /** * Returns R^2^, the coefficient of determination. * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * + * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val r2: Double = metrics.r2 From 52bc25c8e26d4be250d8ff7864067528f4f98592 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 24 Nov 2015 09:56:17 -0800 Subject: [PATCH 875/896] [SPARK-11847][ML] Model export/import for spark.ml: LDA Add read/write support to LDA, similar to ALS. save/load for ml.LocalLDAModel is done. For DistributedLDAModel, I'm not sure if we can invoke save on the mllib.DistributedLDAModel directly. I'll send update after some test. Author: Yuhao Yang Closes #9894 from hhbyyh/ldaMLsave. 
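A short sketch of how the persistence support added below is meant to be used; `dataset` and the save path are placeholders, and the default online optimizer is assumed, which produces a `LocalLDAModel`.

```scala
// Illustrative only: exercises the MLWritable/MLReadable support added in this patch.
// `dataset` stands for a DataFrame with a "features" vector column.
import org.apache.spark.ml.clustering.{LDA, LocalLDAModel}

val lda = new LDA().setK(3).setMaxIter(2)
val model = lda.fit(dataset)            // LocalLDAModel under the default online optimizer

model.write.save("/tmp/lda-model")      // params metadata plus a "data" parquet directory
val restored = LocalLDAModel.load("/tmp/lda-model")

assert(restored.vocabSize == model.vocabSize)
```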
--- .../org/apache/spark/ml/clustering/LDA.scala | 110 +++++++++++++++++- .../spark/mllib/clustering/LDAModel.scala | 4 +- .../apache/spark/ml/clustering/LDASuite.scala | 44 ++++++- 3 files changed, 150 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 92e05815d6a3..830510b1698d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.clustering +import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter} import org.apache.spark.ml.param._ +import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, @@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] ( @Since("1.6.0") override val uid: String, @Since("1.6.0") val vocabSize: Int, @Since("1.6.0") @transient protected val sqlContext: SQLContext) - extends Model[LDAModel] with LDAParams with Logging { + extends Model[LDAModel] with LDAParams with Logging with MLWritable { // NOTE to developers: // This abstraction should contain all important functionality for basic LDA usage. @@ -486,6 +487,64 @@ class LocalLDAModel private[ml] ( @Since("1.6.0") override def isDistributed: Boolean = false + + @Since("1.6.0") + override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this) +} + + +@Since("1.6.0") +object LocalLDAModel extends MLReadable[LocalLDAModel] { + + private[LocalLDAModel] + class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter { + + private case class Data( + vocabSize: Int, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val oldModel = instance.oldLocalModel + val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration, + oldModel.topicConcentration, oldModel.gammaShape) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class LocalLDAModelReader extends MLReader[LocalLDAModel] { + + private val className = classOf[LocalLDAModel].getName + + override def load(path: String): LocalLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration", + "gammaShape") + .head() + val vocabSize = data.getAs[Int](0) + val topicsMatrix = data.getAs[Matrix](1) + val docConcentration = data.getAs[Vector](2) + val topicConcentration = data.getAs[Double](3) + val gammaShape = data.getAs[Double](4) + val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration, + gammaShape) + val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext) + 
DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader + + @Since("1.6.0") + override def load(path: String): LocalLDAModel = super.load(path) } @@ -562,6 +621,45 @@ class DistributedLDAModel private[ml] ( */ @Since("1.6.0") lazy val logPrior: Double = oldDistributedModel.logPrior + + @Since("1.6.0") + override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this) +} + + +@Since("1.6.0") +object DistributedLDAModel extends MLReadable[DistributedLDAModel] { + + private[DistributedLDAModel] + class DistributedWriter(instance: DistributedLDAModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val modelPath = new Path(path, "oldModel").toString + instance.oldDistributedModel.save(sc, modelPath) + } + } + + private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] { + + private val className = classOf[DistributedLDAModel].getName + + override def load(path: String): DistributedLDAModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val modelPath = new Path(path, "oldModel").toString + val oldModel = OldDistributedLDAModel.load(sc, modelPath) + val model = new DistributedLDAModel( + metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("1.6.0") + override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader + + @Since("1.6.0") + override def load(path: String): DistributedLDAModel = super.load(path) } @@ -593,7 +691,8 @@ class DistributedLDAModel private[ml] ( @Since("1.6.0") @Experimental class LDA @Since("1.6.0") ( - @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams { + @Since("1.6.0") override val uid: String) + extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable { @Since("1.6.0") def this() = this(Identifiable.randomUID("lda")) @@ -695,7 +794,7 @@ class LDA @Since("1.6.0") ( } -private[clustering] object LDA { +private[clustering] object LDA extends DefaultParamsReadable[LDA] { /** Get dataset for spark.mllib LDA */ def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = { @@ -706,4 +805,7 @@ private[clustering] object LDA { (docId, features) } } + + @Since("1.6.0") + override def load(path: String): LDA = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index cd520f09bd46..7384d065a2ea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -187,11 +187,11 @@ abstract class LDAModel private[clustering] extends Saveable { * @param topics Inferred topics (vocabSize x k matrix). 
*/ @Since("1.3.0") -class LocalLDAModel private[clustering] ( +class LocalLDAModel private[spark] ( @Since("1.3.0") val topics: Matrix, @Since("1.5.0") override val docConcentration: Vector, @Since("1.5.0") override val topicConcentration: Double, - override protected[clustering] val gammaShape: Double = 100) + override protected[spark] val gammaShape: Double = 100) extends LDAModel with Serializable { @Since("1.3.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index b634d31cc34f..97dbfd9a4314 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} @@ -39,10 +40,24 @@ object LDASuite { }.map(v => new TestRow(v)) sql.createDataFrame(rdd) } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "k" -> 3, + "maxIter" -> 2, + "checkpointInterval" -> 30, + "learningOffset" -> 1023.0, + "learningDecay" -> 0.52, + "subsamplingRate" -> 0.051 + ) } -class LDASuite extends SparkFunSuite with MLlibTestSparkContext { +class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { val k: Int = 5 val vocabSize: Int = 30 @@ -218,4 +233,29 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { val lp = model.logPrior assert(lp <= 0.0 && lp != Double.NegativeInfinity) } + + test("read/write LocalLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) + } + + test("read/write DistributedLDAModel") { + def checkModelData(model: LDAModel, model2: LDAModel): Unit = { + assert(model.vocabSize === model2.vocabSize) + assert(Vectors.dense(model.topicsMatrix.toArray) ~== + Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) + assert(Vectors.dense(model.getDocConcentration) ~== + Vectors.dense(model2.getDocConcentration) absTol 1e-6) + } + val lda = new LDA() + testEstimatorAndModelReadWrite(lda, dataset, + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) + } } From 19530da6903fa59b051eec69b9c17e231c68454b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2015 11:09:01 -0800 Subject: [PATCH 876/896] [SPARK-11926][SQL] unify GetStructField and GetInternalRowField Author: Wenchen Fan Closes #9909 from cloud-fan/get-struct. 
--- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../sql/catalyst/analysis/unresolved.scala | 8 +++---- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../expressions/complexTypeExtractors.scala | 18 ++++++++-------- .../expressions/namedExpressions.scala | 4 ++-- .../sql/catalyst/expressions/objects.scala | 21 ------------------- .../expressions/ComplexTypeSuite.scala | 4 ++-- 9 files changed, 21 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 476becec4dd5..d133ad3f0d89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection { /** Returns the current path with a field at ordinal extracted. */ def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetInternalRowField(p, ordinal, dataType)) + .map(p => GetStructField(p, ordinal)) .getOrElse(BoundReference(ordinal, dataType, false)) /** Returns the current path or `BoundReference`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 6485bdfb3023..1b2a8dc4c7f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu if (attribute.isDefined) { // This target resolved to an attribute in child. It must be a struct. Expand it. attribute.get.dataType match { - case s: StructType => { - s.fields.map( f => { - val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get) + case s: StructType => s.zipWithIndex.map { + case (f, i) => + val extract = GetStructField(attribute.get, i) Alias(extract, target.get + "." + f.name)() - }) } + case _ => { throw new AnalysisException("Can only star expand struct data types. 
Attribute: `" + target.get + "`") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 7bc9aed0b204..0c10a56c555f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -111,7 +111,7 @@ object ExpressionEncoder { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) - case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt) + case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index fa553e7c5324..67518f52d4a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -220,7 +220,7 @@ object RowEncoder { If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(GetInternalRowField(input, i, f.dataType))) + constructorFor(GetStructField(input, i))) } CreateExternalRow(convertedFields) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 540ed3500616..169435a10ea2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] { */ def prettyString: String = { transform { - case a: AttributeReference => PrettyAttribute(a.name) + case a: AttributeReference => PrettyAttribute(a.name, a.dataType) case u: UnresolvedAttribute => PrettyAttribute(u.name) }.toString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index f871b737fff3..10ce10aaf6da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -51,7 +51,7 @@ object ExtractValue { case (StructType(fields), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) - GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + GetStructField(child, ordinal, Some(fieldName)) case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString @@ -97,18 +97,18 @@ object ExtractValue { * Returns the value of fields in the Struct `child`. * * No need to do type checking since it is handled by [[ExtractValue]]. - * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]]. + * + * Note that we can pass in the field name directly to keep case preserving in `toString`. + * For example, when get field `yEAr` from ``, we should pass in `yEAr`. 
*/ -case class GetStructField(child: Expression, field: StructField, ordinal: Int) +case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression { - override def dataType: DataType = child.dataType match { - case s: StructType => s(ordinal).dataType - // This is a hack to avoid breaking existing code until we remove the need for the struct field - case _ => field.dataType - } + private lazy val field = child.dataType.asInstanceOf[StructType](ordinal) + + override def dataType: DataType = field.dataType override def nullable: Boolean = child.nullable || field.nullable - override def toString: String = s"$child.${field.name}" + override def toString: String = s"$child.${name.getOrElse(field.name)}" protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, field.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 00b7970bd16c..26b6aca79971 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -273,7 +273,8 @@ case class AttributeReference( * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. */ -case class PrettyAttribute(name: String) extends Attribute with Unevaluable { +case class PrettyAttribute(name: String, dataType: DataType = NullType) + extends Attribute with Unevaluable { override def toString: String = name @@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable { override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException - override def dataType: DataType = NullType } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 4a1f419f0ad8..62d09f0f5510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -517,27 +517,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { } } -case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType) - extends UnaryExpression { - - override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.isNullAt($ordinal)) { - ${ev.isNull} = true; - } else { - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; - } - """ - }) - } -} - /** * Serializes an input object using a generic serializer (Kryo or Java). * @param kryo if true, use Kryo. Otherwise, use Java. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e60990aeb423..62fd47234b33 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { case StructType(fields) => - val field = fields.find(_.name == fieldName).get - GetStructField(expr, field, fields.indexOf(field)) + val index = fields.indexWhere(_.name == fieldName) + GetStructField(expr, index) } } From 81012546ee5a80d2576740af0dad067b0f5962c5 Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 24 Nov 2015 12:22:33 -0800 Subject: [PATCH 877/896] [SPARK-11872] Prevent the call to SparkContext#stop() in the listener bus's thread This is continuation of SPARK-11761 Andrew suggested adding this protection. See tail of https://github.com/apache/spark/pull/9741 Author: tedyu Closes #9852 from tedyu/master. --- .../scala/org/apache/spark/SparkContext.scala | 4 +++ .../spark/scheduler/SparkListenerSuite.scala | 31 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b153a7b08e59..e19ba113702c 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1694,6 +1694,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Shut down the SparkContext. def stop() { + if (AsynchronousListenerBus.withinListenerThread.value) { + throw new SparkException("Cannot stop SparkContext within listener thread of" + + " AsynchronousListenerBus") + } // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. 
if (!stopped.compareAndSet(false, true)) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 84e545851f49..f20d5be7c0ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.scalatest.Matchers +import org.apache.spark.SparkException import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} @@ -36,6 +37,21 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val jobCompletionTime = 1421191296660L + test("don't call sc.stop in listener") { + sc = new SparkContext("local", "SparkListenerSuite") + val listener = new SparkContextStoppingListener(sc) + val bus = new LiveListenerBus + bus.addListener(listener) + + // Starting listener bus should flush all buffered events + bus.start(sc) + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + + bus.stop() + assert(listener.sparkExSeen) + } + test("basic creation and shutdown of LiveListenerBus") { val counter = new BasicJobCounter val bus = new LiveListenerBus @@ -443,6 +459,21 @@ private class BasicJobCounter extends SparkListener { override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } +/** + * A simple listener that tries to stop SparkContext. + */ +private class SparkContextStoppingListener(val sc: SparkContext) extends SparkListener { + @volatile var sparkExSeen = false + override def onJobEnd(job: SparkListenerJobEnd): Unit = { + try { + sc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } + } +} + private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener { var count = 0 override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 From f3152722791b163fa66597b3684009058195ba33 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 12:54:37 -0800 Subject: [PATCH 878/896] [SPARK-11946][SQL] Audit pivot API for 1.6. Currently pivot's signature looks like ```scala scala.annotation.varargs def pivot(pivotColumn: Column, values: Column*): GroupedData scala.annotation.varargs def pivot(pivotColumn: String, values: Any*): GroupedData ``` I think we can remove the one that takes "Column" types, since callers should always be passing in literals. It'd also be more clear if the values are not varargs, but rather Seq or java.util.List. I also made similar changes for Python. Author: Reynold Xin Closes #9929 from rxin/SPARK-11946. 
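The end state of the API audit above, as a short usage sketch; the `courseSales` DataFrame with (year, course, earnings) columns mirrors the test fixture referenced in the diff below.

```scala
// Sketch of the reworked pivot API; courseSales is assumed to exist as in the tests.
import org.apache.spark.sql.functions.sum

// With explicit values: Spark skips the extra pass that discovers distinct courses.
val pivoted = courseSales.groupBy("year")
  .pivot("course", Seq("dotNET", "Java"))
  .agg(sum("earnings"))

// Without values: more concise, but the distinct values of the pivot column are
// computed first (bounded by SQLConf.DATAFRAME_PIVOT_MAX_VALUES).
val pivotedAuto = courseSales.groupBy("year").pivot("course").sum("earnings")
```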
--- .../apache/spark/scheduler/DAGScheduler.scala | 1 - python/pyspark/sql/group.py | 12 +- .../sql/catalyst/expressions/literals.scala | 1 + .../org/apache/spark/sql/GroupedData.scala | 154 ++++++++++-------- .../apache/spark/sql/JavaDataFrameSuite.java | 16 ++ .../spark/sql/DataFramePivotSuite.scala | 21 +-- .../apache/spark/sql/test/SQLTestData.scala | 1 + 7 files changed, 125 insertions(+), 81 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ae725b467d8c..77a184dfe4be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1574,7 +1574,6 @@ class DAGScheduler( } def stop() { - logInfo("Stopping DAGScheduler") messageScheduler.shutdownNow() eventProcessLoop.stop() taskScheduler.stop() diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 227f40bc3cf5..d8ed7eb2dda6 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -168,20 +168,24 @@ def sum(self, *cols): """ @since(1.6) - def pivot(self, pivot_col, *values): + def pivot(self, pivot_col, values=None): """Pivots a column of the current DataFrame and preform the specified aggregation. :param pivot_col: Column to pivot :param values: Optional list of values of pivotColumn that will be translated to columns in the output data frame. If values are not provided the method with do an immediate call to .distinct() on the pivot column. - >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect() + + >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] + >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ - jgd = self._jdf.pivot(_to_java_column(pivot_col), - _to_seq(self.sql_ctx._sc, values, _create_column_from_literal)) + if values is None: + jgd = self._jdf.pivot(pivot_col) + else: + jgd = self._jdf.pivot(pivot_col, values) return GroupedData(jgd, self.sql_ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e34fd49be838..68ec688c99f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -44,6 +44,7 @@ object Literal { case a: Array[Byte] => Literal(a, BinaryType) case i: CalendarInterval => Literal(i, CalendarIntervalType) case null => Literal(null, NullType) + case v: Literal => v case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 63dd7fbcbe9e..ee7150cbbfbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAli import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, 
Aggregate} -import org.apache.spark.sql.types.{StringType, NumericType} +import org.apache.spark.sql.types.NumericType /** @@ -282,74 +282,96 @@ class GroupedData protected[sql]( } /** - * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified - * aggregation. - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) - * // Or without specifying column values - * df.groupBy($"year").pivot($"course").agg(sum($"earnings")) - * }}} - * @param pivotColumn Column to pivot - * @param values Optional list of values of pivotColumn that will be translated to columns in the - * output data frame. If values are not provided the method with do an immediate - * call to .distinct() on the pivot column. - * @since 1.6.0 - */ - @scala.annotation.varargs - def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match { - case _: GroupedData.PivotType => - throw new UnsupportedOperationException("repeated pivots are not supported") - case GroupedData.GroupByType => - val pivotValues = if (values.nonEmpty) { - values.map { - case Column(literal: Literal) => literal - case other => - throw new UnsupportedOperationException( - s"The values of a pivot must be literals, found $other") - } - } else { - // This is to prevent unintended OOM errors when the number of distinct values is large - val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) - // Get the distinct values of the column and sort them so its consistent - val values = df.select(pivotColumn) - .distinct() - .sort(pivotColumn) - .map(_.get(0)) - .take(maxValues + 1) - .map(Literal(_)).toSeq - if (values.length > maxValues) { - throw new RuntimeException( - s"The pivot column $pivotColumn has more than $maxValues distinct values, " + - "this could indicate an error. " + - "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " + - s"to at least the number of distinct values of the pivot column.") - } - values - } - new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues)) - case _ => - throw new UnsupportedOperationException("pivot is only supported after a groupBy") + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. 
+ * @since 1.6.0 + */ + def pivot(pivotColumn: String): GroupedData = { + // This is to prevent unintended OOM errors when the number of distinct values is large + val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) + // Get the distinct values of the column and sort them so its consistent + val values = df.select(pivotColumn) + .distinct() + .sort(pivotColumn) + .map(_.get(0)) + .take(maxValues + 1) + .toSeq + + if (values.length > maxValues) { + throw new AnalysisException( + s"The pivot column $pivotColumn has more than $maxValues distinct values, " + + "this could indicate an error. " + + s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " + + "to at least the number of distinct values of the pivot column.") + } + + pivot(pivotColumn, values) } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings") - * // Or without specifying column values - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * @param pivotColumn Column to pivot - * @param values Optional list of values of pivotColumn that will be translated to columns in the - * output data frame. If values are not provided the method with do an immediate - * call to .distinct() on the pivot column. - * @since 1.6.0 - */ - @scala.annotation.varargs - def pivot(pivotColumn: String, values: Any*): GroupedData = { - val resolvedPivotColumn = Column(df.resolve(pivotColumn)) - pivot(resolvedPivotColumn, values.map(functions.lit): _*) + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = { + groupType match { + case GroupedData.GroupByType => + new GroupedData( + df, + groupingExprs, + GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + case _: GroupedData.PivotType => + throw new UnsupportedOperationException("repeated pivots are not supported") + case _ => + throw new UnsupportedOperationException("pivot is only supported after a groupBy") + } + } + + /** + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. 
+ * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings"); + * }}} + * + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = { + pivot(pivotColumn, values.asScala) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 567bdddece80..a12fed3c0c6a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -282,4 +282,20 @@ public void testSampleBy() { Assert.assertEquals(1, actual[1].getLong(0)); Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13); } + + @Test + public void pivot() { + DataFrame df = context.table("courseSales"); + Row[] actual = df.groupBy("year") + .pivot("course", Arrays.asList("dotNET", "Java")) + .agg(sum("earnings")).orderBy("year").collect(); + + Assert.assertEquals(2012, actual[0].getInt(0)); + Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01); + Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01); + + Assert.assertEquals(2013, actual[1].getInt(0)); + Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01); + Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 0c23d142670c..fc53aba68ebb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -25,7 +25,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot courses with literals") { checkAnswer( - courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings")), Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil ) @@ -33,14 +33,15 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot year with literals") { checkAnswer( - courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")), + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } test("pivot courses with literals and multiple aggregations") { checkAnswer( - courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + courseSales.groupBy($"year") + .pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings"), avg($"earnings")), Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil @@ -49,14 +50,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot year with string values (cast)") { checkAnswer( - courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"), + courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"), Row("dotNET", 15000.0, 48000.0) 
:: Row("Java", 20000.0, 30000.0) :: Nil ) } test("pivot year with int values") { checkAnswer( - courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"), + courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } @@ -64,22 +65,22 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ test("pivot courses with no values") { // Note Java comes before dotNet in sorted order checkAnswer( - courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), + courseSales.groupBy("year").pivot("course").agg(sum($"earnings")), Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil ) } test("pivot year with no values") { checkAnswer( - courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), + courseSales.groupBy("course").pivot("year").agg(sum($"earnings")), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } - test("pivot max values inforced") { + test("pivot max values enforced") { sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) - intercept[RuntimeException]( - courseSales.groupBy($"year").pivot($"course") + intercept[AnalysisException]( + courseSales.groupBy("year").pivot("course") ) sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index abad0d7eaaed..83c63e04f344 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -281,6 +281,7 @@ private[sql] trait SQLTestData { self => person salary complexData + courseSales } } From e6dd237463d2de8c506f0735dfdb3f43e8122513 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 24 Nov 2015 15:08:02 -0600 Subject: [PATCH 879/896] [SPARK-11929][CORE] Make the repl log4j configuration override the root logger. In the default Spark distribution, there are currently two separate log4j config files, with different default values for the root logger, so that when running the shell you have a different default log level. This makes the shell more usable, since the logs don't overwhelm the output. But if you install a custom log4j.properties, you lose that, because then it's going to be used no matter whether you're running a regular app or the shell. With this change, the overriding of the log level is done differently; the log level repl's main class (org.apache.spark.repl.Main) is used to define the root logger's level when running the shell, defaulting to WARN if it's not set explicitly. On a somewhat related change, the shell output about the "sc" variable was changed a bit to contain a little more useful information about the application, since when the root logger's log level is WARN, that information is never shown to the user. Author: Marcelo Vanzin Closes #9816 from vanzin/shell-logging. 
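As a quick illustration of the behaviour described above (a minimal sketch, not part of the patch; it assumes a running spark-shell where `sc` is already defined):

// Scala, in spark-shell: the root logger now starts at the effective level of the
// org.apache.spark.repl.Main logger (WARN unless configured otherwise).
sc.setLogLevel("INFO")   // show more detail for the current session only
// To change the default persistently, set the repl logger's level in conf/log4j.properties:
//   log4j.logger.org.apache.spark.repl.Main=INFO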
--- conf/log4j.properties.template | 5 +++ .../spark/log4j-defaults-repl.properties | 33 -------------- .../apache/spark/log4j-defaults.properties | 5 +++ .../main/scala/org/apache/spark/Logging.scala | 45 ++++++++++--------- .../apache/spark/repl/SparkILoopInit.scala | 21 ++++----- .../org/apache/spark/repl/SparkILoop.scala | 25 ++++++----- 6 files changed, 57 insertions(+), 77 deletions(-) delete mode 100644 core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index f3046be54d7c..9809b0c82848 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -22,6 +22,11 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n +# Set the default spark-shell log level to WARN. When running the spark-shell, the +# log level for this class is used to overwrite the root logger's log level, so that +# the user can have different defaults for the shell and regular Spark apps. +log4j.logger.org.apache.spark.repl.Main=WARN + # Settings to quiet third party logs that are too verbose log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties deleted file mode 100644 index c85abc35b93b..000000000000 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ /dev/null @@ -1,33 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -# Set everything to be logged to the console -log4j.rootCategory=WARN, console -log4j.appender.console=org.apache.log4j.ConsoleAppender -log4j.appender.console.target=System.err -log4j.appender.console.layout=org.apache.log4j.PatternLayout -log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - -# Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR -log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO -log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO - -# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support -log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL -log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index d44cc85dcbd8..0750488e4adf 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -22,6 +22,11 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n +# Set the default spark-shell log level to WARN. When running the spark-shell, the +# log level for this class is used to overwrite the root logger's log level, so that +# the user can have different defaults for the shell and regular Spark apps. +log4j.logger.org.apache.spark.repl.Main=WARN + # Settings to quiet third party logs that are too verbose log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 69f6e06ee005..e35e158c7e8a 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.apache.log4j.{LogManager, PropertyConfigurator} +import org.apache.log4j.{Level, LogManager, PropertyConfigurator} import org.slf4j.{Logger, LoggerFactory} import org.slf4j.impl.StaticLoggerBinder @@ -119,30 +119,31 @@ trait Logging { val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass) if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements + // scalastyle:off println if (!log4j12Initialized) { - // scalastyle:off println - if (Utils.isInInterpreter) { - val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" - Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps") - System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")") - case None => - System.err.println(s"Spark was unable to load $replDefaultLogProps") - } - } else { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - 
System.err.println(s"Spark was unable to load $defaultLogProps") - } + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + System.err.println(s"Spark was unable to load $defaultLogProps") } - // scalastyle:on println } + + if (Utils.isInInterpreter) { + // Use the repl's main class to define the default log level when running the shell, + // overriding the root logger's config if they're different. + val rootLogger = LogManager.getRootLogger() + val replLogger = LogManager.getLogger("org.apache.spark.repl.Main") + val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN) + if (replLevel != rootLogger.getEffectiveLevel()) { + System.err.printf("Setting default log level to \"%s\".\n", replLevel) + System.err.println("To adjust logging level use sc.setLogLevel(newLevel).") + rootLogger.setLevel(replLevel) + } + } + // scalastyle:on println } Logging.initialized = true diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index bd3314d94eed..99e1e1df33fd 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -123,18 +123,19 @@ private[repl] trait SparkILoopInit { def initializeSpark() { intp.beQuietDuring { command(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.interp.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.interp.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) command(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.interp.createSQLContext() + println("SQL context available as sqlContext.") + _sqlContext + } """) command("import org.apache.spark.SparkContext._") command("import sqlContext.implicits._") diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 33d262558b1f..e91139fb29f6 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -37,18 +37,19 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def initializeSpark() { intp.beQuietDuring { processLine(""" - @transient val sc = { - val _sc = org.apache.spark.repl.Main.createSparkContext() - println("Spark context available as sc.") - _sc - } + @transient val sc = { + val _sc = org.apache.spark.repl.Main.createSparkContext() + println("Spark context available as sc " + + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") + _sc + } """) processLine(""" - @transient val sqlContext = { - val _sqlContext = org.apache.spark.repl.Main.createSQLContext() - println("SQL context available as sqlContext.") - _sqlContext - } + @transient val sqlContext = { + val _sqlContext = org.apache.spark.repl.Main.createSQLContext() + 
println("SQL context available as sqlContext.") + _sqlContext + } """) processLine("import org.apache.spark.SparkContext._") processLine("import sqlContext.implicits._") @@ -85,7 +86,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) /** Available commands */ override def commands: List[LoopCommand] = sparkStandardCommands - /** + /** * We override `loadFiles` because we need to initialize Spark *before* the REPL * sees any files, so that the Spark context is visible in those files. This is a bit of a * hack, but there isn't another hook available to us at this point. @@ -98,7 +99,7 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) object SparkILoop { - /** + /** * Creates an interpreter loop with default settings and feeds * the given code to it as input. */ From 58d9b260556a89a3d0832d583acafba1df7c6751 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 24 Nov 2015 14:33:28 -0800 Subject: [PATCH 880/896] [SPARK-11805] free the array in UnsafeExternalSorter during spilling After calling spill() on SortedIterator, the array inside InMemorySorter is not needed, it should be freed during spilling, this could help to join multiple tables with limited memory. Author: Davies Liu Closes #9793 from davies/free_array. --- .../unsafe/sort/UnsafeExternalSorter.java | 10 +++--- .../unsafe/sort/UnsafeInMemorySorter.java | 31 ++++++++----------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 9a7b2ad06cab..2e4031267473 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -468,6 +468,12 @@ public long spill() throws IOException { } allocatedPages.clear(); } + + // in-memory sorter will not be used after spilling + assert(inMemSorter != null); + released += inMemSorter.getMemoryUsage(); + inMemSorter.free(); + inMemSorter = null; return released; } } @@ -489,10 +495,6 @@ public void loadNext() throws IOException { } upstream = nextUpstream; nextUpstream = null; - - assert(inMemSorter != null); - inMemSorter.free(); - inMemSorter = null; } numRecords--; upstream.loadNext(); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index a218ad4623f4..dce1f15a2963 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -108,6 +108,7 @@ public UnsafeInMemorySorter( */ public void free() { consumer.freeArray(array); + array = null; } public void reset() { @@ -160,28 +161,22 @@ public void insertRecord(long recordPointer, long keyPrefix) { pos++; } - public static final class SortedIterator extends UnsafeSorterIterator { + public final class SortedIterator extends UnsafeSorterIterator { - private final TaskMemoryManager memoryManager; - private final int sortBufferInsertPosition; - private final LongArray sortBuffer; - private int position = 0; + private final int numRecords; + private int position; private Object baseObject; private long baseOffset; private long keyPrefix; private int recordLength; - private SortedIterator( - TaskMemoryManager 
memoryManager, - int sortBufferInsertPosition, - LongArray sortBuffer) { - this.memoryManager = memoryManager; - this.sortBufferInsertPosition = sortBufferInsertPosition; - this.sortBuffer = sortBuffer; + private SortedIterator(int numRecords) { + this.numRecords = numRecords; + this.position = 0; } public SortedIterator clone () { - SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer); + SortedIterator iter = new SortedIterator(numRecords); iter.position = position; iter.baseObject = baseObject; iter.baseOffset = baseOffset; @@ -192,21 +187,21 @@ public SortedIterator clone () { @Override public boolean hasNext() { - return position < sortBufferInsertPosition; + return position / 2 < numRecords; } public int numRecordsLeft() { - return (sortBufferInsertPosition - position) / 2; + return numRecords - position / 2; } @Override public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = sortBuffer.get(position); + final long recordPointer = array.get(position); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); - keyPrefix = sortBuffer.get(position + 1); + keyPrefix = array.get(position + 1); position += 2; } @@ -229,6 +224,6 @@ public void loadNext() { */ public SortedIterator getSortedIterator() { sorter.sort(array, 0, pos / 2, sortComparator); - return new SortedIterator(memoryManager, pos, array); + return new SortedIterator(pos / 2); } } From 34ca392da7097a1fbe48cd6c3ebff51453ca26ca Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 14:51:01 -0800 Subject: [PATCH 881/896] Added a line of comment to explain why the extra sort exists in pivot. --- sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index ee7150cbbfbc..abd531c4ba54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -304,7 +304,7 @@ class GroupedData protected[sql]( // Get the distinct values of the column and sort them so its consistent val values = df.select(pivotColumn) .distinct() - .sort(pivotColumn) + .sort(pivotColumn) // ensure that the output columns are in a consistent logical order .map(_.get(0)) .take(maxValues + 1) .toSeq From c7f95df5c6d8eb2e6f11cf58b704fea34326a5f2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 24 Nov 2015 14:59:14 -0800 Subject: [PATCH 882/896] [SPARK-11783][SQL] Fixes execution Hive client when using remote Hive metastore When using remote Hive metastore, `hive.metastore.uris` is set to the metastore URI. However, it overrides `javax.jdo.option.ConnectionURL` unexpectedly, thus the execution Hive client connects to the actual remote Hive metastore instead of the Derby metastore created in the temporary directory. Cleaning this configuration for the execution Hive client fixes this issue. Author: Cheng Lian Closes #9895 from liancheng/spark-11783.clean-remote-metastore-config. 
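For context, a minimal sketch of the setup that triggers the issue described above (not part of the patch; the metastore URI is a placeholder): a deployment points Spark at a remote metastore through hive-site.xml, e.g. hive.metastore.uris=thrift://metastore-host:9083. With that entry on the classpath, user queries should still resolve table metadata through the remote metastore, while the execution Hive client keeps using the temporary embedded Derby metastore created by HiveContext:

// Scala, assuming a SparkContext `sc` and a hive-site.xml with hive.metastore.uris set (placeholder URI)
val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc)
hiveContext.sql("SHOW TABLES").show()   // table metadata is served by the remote metastore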
--- .../org/apache/spark/sql/hive/HiveContext.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c0bb5af7d5c8..8a4264194ae8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -736,6 +736,21 @@ private[hive] object HiveContext { s"jdbc:derby:;databaseName=${localMetastore.getAbsolutePath};create=true") propMap.put("datanucleus.rdbms.datastoreAdapterClassName", "org.datanucleus.store.rdbms.adapter.DerbyAdapter") + + // SPARK-11783: When "hive.metastore.uris" is set, the metastore connection mode will be + // remote (https://cwiki.apache.org/confluence/display/Hive/AdminManual+MetastoreAdmin + // mentions that "If hive.metastore.uris is empty local mode is assumed, remote otherwise"). + // Remote means that the metastore server is running in its own process. + // When the mode is remote, configurations like "javax.jdo.option.ConnectionURL" will not be + // used (because they are used by the remote metastore server that talks to the database). + // Because execution Hive should always connect to an embedded Derby metastore, we have to + // remove the value of hive.metastore.uris so that the execution Hive client connects to the + // actual embedded Derby metastore instead of the remote metastore. + // You can search for HiveConf.ConfVars.METASTOREURIS in the code of HiveConf (in Hive's repo); + // there you will find that the local metastore mode is only set to true when + // hive.metastore.uris is not set. + propMap.put(ConfVars.METASTOREURIS.varname, "") + + propMap.toMap } From 238ae51b66ac12d15fba6aff061804004c5ca6cb Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 24 Nov 2015 15:54:10 -0800 Subject: [PATCH 883/896] [SPARK-11914][SQL] Support coalesce and repartition in Dataset APIs This PR provides the two common methods `coalesce` and `repartition` in the Dataset API. After reading the comments of SPARK-9999, I am unclear about the plan for supporting re-partitioning in the Dataset API. Currently, both the RDD and DataFrame APIs give users the flexibility to control the number of partitions. Most traditional RDBMSs expose the number of partitions, the partitioning columns, and the table partitioning methods to DBAs for performance tuning and storage planning. Normally, these parameters can largely affect query performance. Since the actual performance depends on the workload types, I think it is almost impossible to automate the discovery of the best partitioning strategy for all scenarios. I am wondering whether the Dataset API is meant to hide these APIs from users. Feel free to reject my PR if it does not match the plan. Thank you for your answers. marmbrus rxin cloud-fan Author: gatorsmile Closes #9899 from gatorsmile/coalesce.
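A minimal usage sketch of the two methods added below (assumes a spark-shell session where `sqlContext` is defined; the sample data is made up):

import sqlContext.implicits._

val ds = (1 to 100).toDS()                 // Dataset[Int]
val reshuffled = ds.repartition(10)        // full shuffle into exactly 10 partitions
val narrowed = ds.coalesce(1)              // narrow dependency, avoids a shuffle
println(reshuffled.rdd.partitions.length)  // 10
println(narrowed.rdd.partitions.length)    // 1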
--- .../scala/org/apache/spark/sql/Dataset.scala | 19 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 15 +++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 07647508421a..17e2611790d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -152,6 +152,25 @@ class Dataset[T] private[sql]( */ def count(): Long = toDF().count() + /** + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * @since 1.6.0 + */ + def repartition(numPartitions: Int): Dataset[T] = withPlan { + Repartition(numPartitions, shuffle = true, _) + } + + /** + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. + * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of + * the 100 new partitions will claim 10 of the current partitions. + * @since 1.6.0 + */ + def coalesce(numPartitions: Int): Dataset[T] = withPlan { + Repartition(numPartitions, shuffle = false, _) + } + /* *********************** * * Functional Operations * * *********************** */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 13eede1b17d8..c253fdbb8c99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -52,6 +52,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.takeAsList(1).get(0) == item) } + test("coalesce, repartition") { + val data = (1 to 100).map(i => ClassData(i.toString, i)) + val ds = data.toDS() + + assert(ds.repartition(10).rdd.partitions.length == 10) + checkAnswer( + ds.repartition(10), + data: _*) + + assert(ds.coalesce(1).rdd.partitions.length == 1) + checkAnswer( + ds.coalesce(1), + data: _*) + } + test("as tuple") { val data = Seq(("a", 1), ("b", 2)).toDF("a", "b") checkAnswer( From 25bbd3c16e8e8be4d2c43000223d54650e9a3696 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 18:16:07 -0800 Subject: [PATCH 884/896] [SPARK-11967][SQL] Consistent use of varargs for multiple paths in DataFrameReader This patch makes it consistent to use varargs in all DataFrameReader methods, including Parquet, JSON, text, and the generic load function. Also added a few more API tests for the Java API. Author: Reynold Xin Closes #9945 from rxin/SPARK-11967. --- python/pyspark/sql/readwriter.py | 19 ++++++---- .../apache/spark/sql/DataFrameReader.scala | 36 +++++++++++++++---- .../apache/spark/sql/JavaDataFrameSuite.java | 23 ++++++++++++ sql/core/src/test/resources/text-suite2.txt | 1 + .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- 5 files changed, 66 insertions(+), 15 deletions(-) create mode 100644 sql/core/src/test/resources/text-suite2.txt diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index e8f0d7ec7703..2e75f0c8a182 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -109,7 +109,7 @@ def options(self, **options): def load(self, path=None, format=None, schema=None, **options): """Loads data from a data source and returns it as a :class`DataFrame`. - :param path: optional string for file-system backed data sources. 
+ :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. :param schema: optional :class:`StructType` for the input schema. :param options: all other string options @@ -118,6 +118,7 @@ def load(self, path=None, format=None, schema=None, **options): ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json', ... 'python/test_support/sql/people1.json']) >>> df.dtypes @@ -130,10 +131,8 @@ def load(self, path=None, format=None, schema=None, **options): self.options(**options) if path is not None: if type(path) == list: - paths = path - gateway = self._sqlContext._sc._gateway - jpaths = utils.toJArray(gateway, gateway.jvm.java.lang.String, paths) - return self._df(self._jreader.load(jpaths)) + return self._df( + self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) else: return self._df(self._jreader.load(path)) else: @@ -175,6 +174,8 @@ def json(self, path, schema=None): self.schema(schema) if isinstance(path, basestring): return self._df(self._jreader.json(path)) + elif type(path) == list: + return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): return self._df(self._jreader.json(path._jrdd)) else: @@ -205,16 +206,20 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) - def text(self, path): + def text(self, paths): """Loads a text file and returns a [[DataFrame]] with a single string column named "text". Each line in the text file is a new row in the resulting DataFrame. + :param paths: string, or list of strings, for input path(s). 
+ >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') >>> df.collect() [Row(value=u'hello'), Row(value=u'this')] """ - return self._df(self._jreader.text(path)) + if isinstance(paths, basestring): + paths = [paths] + return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) @since(1.5) def orc(self, path): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index dcb3737b70fb..3ed1e55adec6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,17 +24,17 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.util.StringUtils +import org.apache.spark.{Logging, Partition} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.SqlParser import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation} +import org.apache.spark.sql.execution.datasources.json.JSONRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} /** * :: Experimental :: @@ -104,6 +104,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. def load(path: String): DataFrame = { option("path", path).load() } @@ -130,7 +131,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * @since 1.6.0 */ - def load(paths: Array[String]): DataFrame = { + @scala.annotation.varargs + def load(paths: String*): DataFrame = { option("paths", paths.map(StringUtils.escapeString(_, '\\', ',')).mkString(",")).load() } @@ -236,11 +238,30 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { *
<li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers * (e.g. 00012)</li>
* - * @param path input path * @since 1.4.0 */ + // TODO: Remove this one in Spark 2.0. def json(path: String): DataFrame = format("json").load(path) + /** + * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * You can set the following JSON-specific options to deal with non-standard JSON files:
+ * <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
+ * <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
+ * <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
+ * <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes + * </li>
+ * <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers * (e.g. 00012)
  • + * + * @since 1.6.0 + */ + def json(paths: String*): DataFrame = format("json").load(paths : _*) + /** * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and * returns the result as a [[DataFrame]]. @@ -328,10 +349,11 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * sqlContext.read().text("/path/to/spark/README.md") * }}} * - * @param path input path + * @param paths input path * @since 1.6.0 */ - def text(path: String): DataFrame = format("text").load(path) + @scala.annotation.varargs + def text(paths: String*): DataFrame = format("text").load(paths : _*) /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index a12fed3c0c6a..8e0b2dbca4a9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -298,4 +298,27 @@ public void pivot() { Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01); Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); } + + public void testGenericLoad() { + DataFrame df1 = context.read().format("text").load( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); + Assert.assertEquals(4L, df1.count()); + + DataFrame df2 = context.read().format("text").load( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), + Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); + Assert.assertEquals(5L, df2.count()); + } + + @Test + public void testTextLoad() { + DataFrame df1 = context.read().text( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); + Assert.assertEquals(4L, df1.count()); + + DataFrame df2 = context.read().text( + Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), + Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); + Assert.assertEquals(5L, df2.count()); + } } diff --git a/sql/core/src/test/resources/text-suite2.txt b/sql/core/src/test/resources/text-suite2.txt new file mode 100644 index 000000000000..f9d498c80493 --- /dev/null +++ b/sql/core/src/test/resources/text-suite2.txt @@ -0,0 +1 @@ +This is another file for testing multi path loading. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index dd6d06512ff6..76e9648aa753 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -897,7 +897,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val dir2 = new File(dir, "dir2").getCanonicalPath df2.write.format("json").save(dir2) - checkAnswer(sqlContext.read.format("json").load(Array(dir1, dir2)), + checkAnswer(sqlContext.read.format("json").load(dir1, dir2), Row(1, 22) :: Row(2, 23) :: Nil) checkAnswer(sqlContext.read.format("json").load(dir1), From 4d6bbbc03ddb6650b00eb638e4876a196014c19c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 18:58:55 -0800 Subject: [PATCH 885/896] [SPARK-11947][SQL] Mark deprecated methods with "This will be removed in Spark 2.0." Also fixed some documentation as I saw them. 
Author: Reynold Xin Closes #9930 from rxin/SPARK-11947. --- project/MimaExcludes.scala | 3 +- .../scala/org/apache/spark/sql/Column.scala | 20 +++-- .../org/apache/spark/sql/DataFrame.scala | 72 +++++++++------ .../scala/org/apache/spark/sql/Dataset.scala | 1 + .../org/apache/spark/sql/SQLContext.scala | 88 ++++++++++--------- .../org/apache/spark/sql/SQLImplicits.scala | 25 +++++- .../sql/{ => execution}/SparkSQLParser.scala | 15 ++-- .../org/apache/spark/sql/functions.scala | 52 ++++++----- .../SimpleTextHadoopFsRelationSuite.scala | 6 +- 9 files changed, 172 insertions(+), 110 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{ => execution}/SparkSQLParser.scala (89%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bb45d1bb1214..54a9ad956d11 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -108,7 +108,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") + "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSQLParser") ) ++ Seq( // SPARK-11485 ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 30c554a85e69..b3cd9e1eff14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -42,7 +42,8 @@ private[sql] object Column { /** * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. - * @since 1.6.0 + * To create a [[TypedColumn]], use the `as` function on a [[Column]]. + * * @tparam T The input type expected for this expression. Can be `Any` if the expression is type * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). * @tparam U The output type of this column. @@ -51,7 +52,8 @@ private[sql] object Column { */ class TypedColumn[-T, U]( expr: Expression, - private[sql] val encoder: ExpressionEncoder[U]) extends Column(expr) { + private[sql] val encoder: ExpressionEncoder[U]) + extends Column(expr) { /** * Inserts the specific input type and schema into any expressions that are expected to operate @@ -61,12 +63,11 @@ class TypedColumn[-T, U]( inputEncoder: ExpressionEncoder[_], schema: Seq[Attribute]): TypedColumn[T, U] = { val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]] - new TypedColumn[T, U] (expr transform { - case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => - ta.copy( - aEncoder = Some(boundEncoder), - children = schema) - }, encoder) + new TypedColumn[T, U]( + expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy(aEncoder = Some(boundEncoder), children = schema) + }, + encoder) } } @@ -691,8 +692,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops * @since 1.3.0 + * @deprecated As of 1.5.0. Use isin. This will be removed in Spark 2.0. */ - @deprecated("use isin", "1.5.0") + @deprecated("use isin. 
This will be removed in Spark 2.0.", "1.5.0") @scala.annotation.varargs def in(list: Any*): Column = isin(list : _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5586fc994b98..5eca1db9525e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1713,9 +1713,9 @@ class DataFrame private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * @deprecated As of 1.3.0, replaced by `toDF()`. + * @deprecated As of 1.3.0, replaced by `toDF()`. This will be removed in Spark 2.0. */ - @deprecated("use toDF", "1.3.0") + @deprecated("Use toDF. This will be removed in Spark 2.0.", "1.3.0") def toSchemaRDD: DataFrame = this /** @@ -1725,9 +1725,9 @@ class DataFrame private[sql]( * given name; if you pass `false`, it will throw if the table already * exists. * @group output - * @deprecated As of 1.340, replaced by `write().jdbc()`. + * @deprecated As of 1.340, replaced by `write().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.jdbc()", "1.4.0") + @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { val w = if (allowExisting) write.mode(SaveMode.Overwrite) else write w.jdbc(url, table, new Properties) @@ -1744,9 +1744,9 @@ class DataFrame private[sql]( * the RDD in order via the simple statement * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. * @group output - * @deprecated As of 1.4.0, replaced by `write().jdbc()`. + * @deprecated As of 1.4.0, replaced by `write().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.jdbc()", "1.4.0") + @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { val w = if (overwrite) write.mode(SaveMode.Overwrite) else write.mode(SaveMode.Append) w.jdbc(url, table, new Properties) @@ -1757,9 +1757,9 @@ class DataFrame private[sql]( * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. * @group output - * @deprecated As of 1.4.0, replaced by `write().parquet()`. + * @deprecated As of 1.4.0, replaced by `write().parquet()`. This will be removed in Spark 2.0. */ - @deprecated("Use write.parquet(path)", "1.4.0") + @deprecated("Use write.parquet(path). This will be removed in Spark 2.0.", "1.4.0") def saveAsParquetFile(path: String): Unit = { write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) } @@ -1782,8 +1782,9 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.saveAsTable(tableName). This will be removed in Spark 2.0.", "1.4.0") def saveAsTable(tableName: String): Unit = { write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) } @@ -1805,8 +1806,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(mode).saveAsTable(tableName). 
This will be removed in Spark 2.0.", + "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { write.mode(mode).saveAsTable(tableName) } @@ -1829,8 +1832,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.format(source).saveAsTable(tableName). This will be removed in Spark 2.0.", + "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { write.format(source).saveAsTable(tableName) } @@ -1853,8 +1858,10 @@ class DataFrame private[sql]( * * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { write.format(source).mode(mode).saveAsTable(tableName) } @@ -1877,9 +1884,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable( tableName: String, source: String, @@ -1907,9 +1915,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", - "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def saveAsTable( tableName: String, source: String, @@ -1923,9 +1932,9 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output - * @deprecated As of 1.4.0, replaced by `write().save(path)`. + * @deprecated As of 1.4.0, replaced by `write().save(path)`. This will be removed in Spark 2.0. */ - @deprecated("Use write.save(path)", "1.4.0") + @deprecated("Use write.save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String): Unit = { write.save(path) } @@ -1935,8 +1944,9 @@ class DataFrame private[sql]( * using the default data source configured by spark.sql.sources.default. * @group output * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(mode).save(path)", "1.4.0") + @deprecated("Use write.mode(mode).save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String, mode: SaveMode): Unit = { write.mode(mode).save(path) } @@ -1946,8 +1956,9 @@ class DataFrame private[sql]( * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. + * This will be removed in Spark 2.0. 
*/ - @deprecated("Use write.format(source).save(path)", "1.4.0") + @deprecated("Use write.format(source).save(path). This will be removed in Spark 2.0.", "1.4.0") def save(path: String, source: String): Unit = { write.format(source).save(path) } @@ -1957,8 +1968,10 @@ class DataFrame private[sql]( * [[SaveMode]] specified by mode. * @group output * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") + @deprecated("Use write.format(source).mode(mode).save(path). " + + "This will be removed in Spark 2.0.", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { write.format(source).mode(mode).save(path) } @@ -1969,8 +1982,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).save(). " + + "This will be removed in Spark 2.0.", "1.4.0") def save( source: String, mode: SaveMode, @@ -1985,8 +2000,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().format(source).mode(mode).options(options).save(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") + @deprecated("Use write.format(source).mode(mode).options(options).save(). " + + "This will be removed in Spark 2.0.", "1.4.0") def save( source: String, mode: SaveMode, @@ -1994,14 +2011,15 @@ class DataFrame private[sql]( write.format(source).mode(mode).options(options).save() } - /** * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. * @group output * @deprecated As of 1.4.0, replaced by * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName). " + + "This will be removed in Spark 2.0.", "1.4.0") def insertInto(tableName: String, overwrite: Boolean): Unit = { write.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append).insertInto(tableName) } @@ -2012,8 +2030,10 @@ class DataFrame private[sql]( * @group output * @deprecated As of 1.4.0, replaced by * `write().mode(SaveMode.Append).saveAsTable(tableName)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") + @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName). 
" + + "This will be removed in Spark 2.0.", "1.4.0") def insertInto(tableName: String): Unit = { write.mode(SaveMode.Append).insertInto(tableName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 17e2611790d5..dd84b8bc11e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType /** + * :: Experimental :: * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel * using functional or relational operations. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 39471d2fb79a..46bf544fd885 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -942,33 +942,33 @@ class SQLContext private[sql]( //////////////////////////////////////////////////////////////////////////// /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { createDataFrame(rowRDD, schema) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. */ - @deprecated("use createDataFrame", "1.3.0") + @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd, beanClass) } @@ -978,9 +978,9 @@ class SQLContext private[sql]( * [[DataFrame]] if no paths are passed in. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().parquet()`. + * @deprecated As of 1.4.0, replaced by `read().parquet()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.parquet()", "1.4.0") + @deprecated("Use read.parquet(). This will be removed in Spark 2.0.", "1.4.0") @scala.annotation.varargs def parquetFile(paths: String*): DataFrame = { if (paths.isEmpty) { @@ -995,9 +995,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. 
This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String): DataFrame = { read.json(path) } @@ -1007,18 +1007,18 @@ class SQLContext private[sql]( * returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String, schema: StructType): DataFrame = { read.schema(schema).json(path) } /** * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonFile(path: String, samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(path) } @@ -1029,9 +1029,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String]): DataFrame = read.json(json) /** @@ -1040,9 +1040,9 @@ class SQLContext private[sql]( * It goes through the entire dataset once to determine the schema. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) /** @@ -1050,9 +1050,9 @@ class SQLContext private[sql]( * returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { read.schema(schema).json(json) } @@ -1062,9 +1062,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { read.schema(schema).json(json) } @@ -1074,9 +1074,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). 
This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(json) } @@ -1086,9 +1086,9 @@ class SQLContext private[sql]( * schema, returning the result as a [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. + * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. */ - @deprecated("Use read.json()", "1.4.0") + @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { read.option("samplingRatio", samplingRatio.toString).json(json) } @@ -1098,9 +1098,9 @@ class SQLContext private[sql]( * using the default data source configured by spark.sql.sources.default. * * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().load(path)`. + * @deprecated As of 1.4.0, replaced by `read().load(path)`. This will be removed in Spark 2.0. */ - @deprecated("Use read.load(path)", "1.4.0") + @deprecated("Use read.load(path). This will be removed in Spark 2.0.", "1.4.0") def load(path: String): DataFrame = { read.load(path) } @@ -1110,8 +1110,9 @@ class SQLContext private[sql]( * * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. + * This will be removed in Spark 2.0. */ - @deprecated("Use read.format(source).load(path)", "1.4.0") + @deprecated("Use read.format(source).load(path). This will be removed in Spark 2.0.", "1.4.0") def load(path: String, source: String): DataFrame = { read.format(source).load(path) } @@ -1122,8 +1123,10 @@ class SQLContext private[sql]( * * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. + * This will be removed in Spark 2.0. */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, options: java.util.Map[String, String]): DataFrame = { read.options(options).format(source).load() } @@ -1135,7 +1138,8 @@ class SQLContext private[sql]( * @group genericdata * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. */ - @deprecated("Use read.format(source).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, options: Map[String, String]): DataFrame = { read.options(options).format(source).load() } @@ -1148,7 +1152,8 @@ class SQLContext private[sql]( * @deprecated As of 1.4.0, replaced by * `read().format(source).schema(schema).options(options).load()`. */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).schema(schema).options(options).load(). " + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { read.format(source).schema(schema).options(options).load() @@ -1162,7 +1167,8 @@ class SQLContext private[sql]( * @deprecated As of 1.4.0, replaced by * `read().format(source).schema(schema).options(options).load()`. */ - @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + @deprecated("Use read.format(source).schema(schema).options(options).load(). 
" + + "This will be removed in Spark 2.0.", "1.4.0") def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { read.format(source).schema(schema).options(options).load() } @@ -1172,9 +1178,9 @@ class SQLContext private[sql]( * url named table. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc(url: String, table: String): DataFrame = { read.jdbc(url, table, new Properties) } @@ -1190,9 +1196,9 @@ class SQLContext private[sql]( * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc( url: String, table: String, @@ -1210,9 +1216,9 @@ class SQLContext private[sql]( * of the [[DataFrame]]. * * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. */ - @deprecated("use read.jdbc()", "1.4.0") + @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { read.jdbc(url, table, theParts, new Properties) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 25ffdcde1771..6735d02954b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -30,19 +30,38 @@ import org.apache.spark.unsafe.types.UTF8String /** * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + * + * @since 1.6.0 */ abstract class SQLImplicits { + protected def _sqlContext: SQLContext + /** @since 1.6.0 */ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + /** @since 1.6.0 */ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + /** @since 1.6.0 */ + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + + /** @since 1.6.0 */ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() /** @@ -84,9 +103,9 @@ abstract class SQLImplicits { DataFrameHolder(_sqlContext.createDataFrame(data)) } - // Do NOT add more implicit conversions. They are likely to break source compatibility by - // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous - // because of [[DoubleRDDFunctions]]. + // Do NOT add more implicit conversions for primitive types. 
+ // They are likely to break source compatibility by making existing implicit conversions + // ambiguous. In particular, RDD[Double] is dangerous because of [[DoubleRDDFunctions]]. /** * Creates a single column DataFrame from an RDD[Int]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala similarity index 89% rename from sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala index ea8fce6ca9cf..b3e8d0d84937 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala @@ -15,24 +15,23 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.execution import scala.util.parsing.combinator.RegexParsers import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{DescribeFunction, LogicalPlan, ShowFunctions} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StringType - /** * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. * * @param fallback A function that parses an input string to a logical plan */ -private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { +class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { // A parser for the key-value part of the "SET [key = [value ]]" syntax private object SetCommandParser extends RegexParsers { @@ -100,14 +99,14 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr case _ ~ dbName => ShowTablesCommand(dbName) } | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ { - case Some(f) => ShowFunctions(f._1, Some(f._2)) - case None => ShowFunctions(None, None) + case Some(f) => logical.ShowFunctions(f._1, Some(f._2)) + case None => logical.ShowFunctions(None, None) } ) private lazy val desc: Parser[LogicalPlan] = DESCRIBE ~ FUNCTION ~> EXTENDED.? ~ (ident | stringLit) ^^ { - case isExtended ~ functionName => DescribeFunction(functionName, isExtended.isDefined) + case isExtended ~ functionName => logical.DescribeFunction(functionName, isExtended.isDefined) } private lazy val others: Parser[LogicalPlan] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6137ce3a70fd..77dd5bc72508 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql - - import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try @@ -39,11 +37,11 @@ import org.apache.spark.util.Utils * "bridge" methods due to the use of covariant return types. 
* * {{{ - * In LegacyFunctions: - * public abstract org.apache.spark.sql.Column avg(java.lang.String); + * // In LegacyFunctions: + * public abstract org.apache.spark.sql.Column avg(java.lang.String); * - * In functions: - * public static org.apache.spark.sql.TypedColumn avg(...); + * // In functions: + * public static org.apache.spark.sql.TypedColumn avg(...); * }}} * * This allows us to use the same functions both in typed [[Dataset]] operations and untyped @@ -2528,8 +2526,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = withExpr { ScalaUDF(f, returnType, Seq()) } @@ -2541,8 +2540,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr)) } @@ -2554,8 +2554,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } @@ -2567,8 +2568,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } @@ -2580,8 +2582,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } @@ -2593,8 +2596,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } @@ -2606,8 +2610,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. 
*/ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } @@ -2619,8 +2624,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } @@ -2632,8 +2638,9 @@ object functions extends LegacyFunctions { * @group udf_funcs * @since 1.3.0 * @deprecated As of 1.5.0, since it's redundant with udf() + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } @@ -2644,9 +2651,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() + * @deprecated As of 1.5.0, since it's redundant with udf(). + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } @@ -2657,9 +2665,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() + * @deprecated As of 1.5.0, since it's redundant with udf(). + * This will be removed in Spark 2.0. */ - @deprecated("Use udf", "1.5.0") + @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } @@ -2700,9 +2709,10 @@ object functions extends LegacyFunctions { * * @group udf_funcs * @since 1.4.0 - * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF + * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF. + * This will be removed in Spark 2.0. */ - @deprecated("Use callUDF", "1.5.0") + @deprecated("Use callUDF. 
This will be removed in Spark 2.0.", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = withExpr { // Note: we avoid using closures here because on file systems that are case-insensitive, the // compiled class file for the closure here will conflict with the one in callUDF (upper case). diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 81af684ba0bf..b554d135e4b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -80,7 +80,11 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat private var partitionedDF: DataFrame = _ - private val partitionedDataSchema: StructType = StructType('a.int :: 'b.int :: 'c.string :: Nil) + private val partitionedDataSchema: StructType = + new StructType() + .add("a", IntegerType) + .add("b", IntegerType) + .add("c", StringType) protected override def beforeAll(): Unit = { this.tempPath = Utils.createTempDir() From a5d988763319f63a8e2b58673dd4f9098f17c835 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 24 Nov 2015 20:58:47 -0800 Subject: [PATCH 886/896] [STREAMING][FLAKY-TEST] Catch execution context race condition in `FileBasedWriteAheadLog.close()` There is a race condition in `FileBasedWriteAheadLog.close()`, where if delete's of old log files are in progress, the write ahead log may close, and result in a `RejectedExecutionException`. This is okay, and should be handled gracefully. Example test failures: https://amplab.cs.berkeley.edu/jenkins/job/Spark-1.6-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop1.0,label=spark-test/95/testReport/junit/org.apache.spark.streaming.util/BatchedWriteAheadLogWithCloseFileAfterWriteSuite/BatchedWriteAheadLog___clean_old_logs/ The reason the test fails is in `afterEach`, `writeAheadLog.close` is called, and there may still be async deletes in flight. tdas zsxwing Author: Burak Yavuz Closes #9953 from brkyvz/flaky-ss. 
--- .../streaming/util/FileBasedWriteAheadLog.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 72705f1a9c01..f5165f7c3912 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer -import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor} import java.util.{Iterator => JIterator} import scala.collection.JavaConverters._ @@ -176,10 +176,16 @@ private[streaming] class FileBasedWriteAheadLog( } oldLogFiles.foreach { logInfo => if (!executionContext.isShutdown) { - val f = Future { deleteFile(logInfo) }(executionContext) - if (waitForCompletion) { - import scala.concurrent.duration._ - Await.ready(f, 1 second) + try { + val f = Future { deleteFile(logInfo) }(executionContext) + if (waitForCompletion) { + import scala.concurrent.duration._ + Await.ready(f, 1 second) + } + } catch { + case e: RejectedExecutionException => + logWarning("Execution context shutdown before deleting old WriteAheadLogs. " + + "This would not affect recovery correctness.", e) } } } From 151d7c2baf18403e6e59e97c80c8bcded6148038 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Nov 2015 21:30:53 -0800 Subject: [PATCH 887/896] [SPARK-10621][SQL] Consistent naming for functions in SQL, Python, Scala Author: Reynold Xin Closes #9948 from rxin/SPARK-10621. --- python/pyspark/sql/functions.py | 111 +++++++++++++--- .../org/apache/spark/sql/functions.scala | 124 ++++++++++++++---- 2 files changed, 196 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a1ca723bbd7a..e3786e0fa5fb 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -150,18 +150,18 @@ def _(): _window_functions = { 'rowNumber': - """returns a sequential number starting at 1 within a window partition. - - This is equivalent to the ROW_NUMBER function in SQL.""", + """.. note:: Deprecated in 1.6, use row_number instead.""", + 'row_number': + """returns a sequential number starting at 1 within a window partition.""", 'denseRank': + """.. note:: Deprecated in 1.6, use dense_rank instead.""", + 'dense_rank': """returns the rank of rows within a window partition, without any gaps. The difference between rank and denseRank is that denseRank leaves no gaps in ranking sequence when there are ties. That is, if you were ranking a competition using denseRank and had three people tie for second place, you would say that all three were in second - place and that the next person came in third. - - This is equivalent to the DENSE_RANK function in SQL.""", + place and that the next person came in third.""", 'rank': """returns the rank of rows within a window partition. @@ -172,14 +172,14 @@ def _(): This is equivalent to the RANK function in SQL.""", 'cumeDist': + """.. note:: Deprecated in 1.6, use cume_dist instead.""", + 'cume_dist': """returns the cumulative distribution of values within a window partition, - i.e. the fraction of rows that are below the current row. - - This is equivalent to the CUME_DIST function in SQL.""", + i.e. 
the fraction of rows that are below the current row.""", 'percentRank': - """returns the relative rank (i.e. percentile) of rows within a window partition. - - This is equivalent to the PERCENT_RANK function in SQL.""", + """.. note:: Deprecated in 1.6, use percent_rank instead.""", + 'percent_rank': + """returns the relative rank (i.e. percentile) of rows within a window partition.""", } for _name, _doc in _functions.items(): @@ -189,7 +189,7 @@ def _(): for _name, _doc in _binary_mathfunctions.items(): globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) for _name, _doc in _window_functions.items(): - globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) + globals()[_name] = since(1.6)(_create_window_function(_name, _doc)) for _name, _doc in _functions_1_6.items(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) del _name, _doc @@ -288,6 +288,38 @@ def countDistinct(col, *cols): @since(1.4) def monotonicallyIncreasingId(): + """ + .. note:: Deprecated in 1.6, use monotonically_increasing_id instead. + """ + return monotonically_increasing_id() + + +@since(1.6) +def input_file_name(): + """Creates a string column for the file name of the current Spark task. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.input_file_name()) + + +@since(1.6) +def isnan(col): + """An expression that returns true iff the column is NaN. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.isnan(_to_java_column(col))) + + +@since(1.6) +def isnull(col): + """An expression that returns true iff the column is null. + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.isnull(_to_java_column(col))) + + +@since(1.6) +def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. @@ -300,11 +332,21 @@ def monotonicallyIncreasingId(): 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) - >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + >>> df0.select(monotonically_increasing_id().alias('id')).collect() [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.monotonicallyIncreasingId()) + return Column(sc._jvm.functions.monotonically_increasing_id()) + + +@since(1.6) +def nanvl(col1, col2): + """Returns col1 if it is not NaN, or col2 if col1 is NaN. + + Both inputs should be floating point columns (DoubleType or FloatType). + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) @since(1.4) @@ -382,15 +424,23 @@ def shiftRightUnsigned(col, numBits): @since(1.4) def sparkPartitionId(): + """ + .. note:: Deprecated in 1.6, use spark_partition_id instead. + """ + return spark_partition_id() + + +@since(1.6) +def spark_partition_id(): """A column for partition ID of the Spark task. Note that this is indeterministic because it depends on data partitioning and task scheduling. 
- >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect() + >>> df.repartition(1).select(spark_partition_id().alias("pid")).collect() [Row(pid=0), Row(pid=0)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.sparkPartitionId()) + return Column(sc._jvm.functions.spark_partition_id()) @since(1.5) @@ -1410,6 +1460,33 @@ def explode(col): return Column(jc) +@since(1.6) +def get_json_object(col, path): + """ + Extracts json object from a json string based on json path specified, and returns json string + of the extracted json object. It will return null if the input json string is invalid. + + :param col: string column in json format + :param path: path to the json object to extract + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.get_json_object(_to_java_column(col), path) + return Column(jc) + + +@since(1.6) +def json_tuple(col, fields): + """Creates a new row for a json column according to the given field names. + + :param col: string column in json format + :param fields: list of fields to extract + + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.json_tuple(_to_java_column(col), fields) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 77dd5bc72508..276c5dfc8b06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -472,6 +472,13 @@ object functions extends LegacyFunctions { // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `cume_dist`. This will be removed in Spark 2.0. + */ + @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") + def cumeDist(): Column = cume_dist() + /** * Window function: returns the cumulative distribution of values within a window partition, * i.e. the fraction of rows that are below the current row. @@ -481,13 +488,17 @@ object functions extends LegacyFunctions { * cumeDist(x) = number of values before (and including) x / N * }}} * - * - * This is equivalent to the CUME_DIST function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def cumeDist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } + def cume_dist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } + + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `dense_rank`. This will be removed in Spark 2.0. + */ + @deprecated("Use dense_rank. This will be removed in Spark 2.0.", "1.6.0") + def denseRank(): Column = dense_rank() /** * Window function: returns the rank of rows within a window partition, without any gaps. @@ -497,12 +508,10 @@ object functions extends LegacyFunctions { * and had three people tie for second place, you would say that all three were in second * place and that the next person came in third. * - * This is equivalent to the DENSE_RANK function in SQL. 
- * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def denseRank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } + def dense_rank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -620,6 +629,13 @@ object functions extends LegacyFunctions { */ def ntile(n: Int): Column = withExpr { UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) } + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `percent_rank`. This will be removed in Spark 2.0. + */ + @deprecated("Use percent_rank. This will be removed in Spark 2.0.", "1.6.0") + def percentRank(): Column = percent_rank() + /** * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. * @@ -631,9 +647,9 @@ object functions extends LegacyFunctions { * This is equivalent to the PERCENT_RANK function in SQL. * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def percentRank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } + def percent_rank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } /** * Window function: returns the rank of rows within a window partition. @@ -650,15 +666,20 @@ object functions extends LegacyFunctions { */ def rank(): Column = withExpr { UnresolvedWindowFunction("rank", Nil) } + /** + * @group window_funcs + * @deprecated As of 1.6.0, replaced by `row_number`. This will be removed in Spark 2.0. + */ + @deprecated("Use row_number. This will be removed in Spark 2.0.", "1.6.0") + def rowNumber(): Column = row_number() + /** * Window function: returns a sequential number starting at 1 within a window partition. * - * This is equivalent to the ROW_NUMBER function in SQL. - * * @group window_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def rowNumber(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } + def row_number(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions @@ -720,20 +741,43 @@ object functions extends LegacyFunctions { @scala.annotation.varargs def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) } + /** + * @group normal_funcs + * @deprecated As of 1.6.0, replaced by `input_file_name`. This will be removed in Spark 2.0. + */ + @deprecated("Use input_file_name. This will be removed in Spark 2.0.", "1.6.0") + def inputFileName(): Column = input_file_name() + /** * Creates a string column for the file name of the current Spark task. * * @group normal_funcs + * @since 1.6.0 */ - def inputFileName(): Column = withExpr { InputFileName() } + def input_file_name(): Column = withExpr { InputFileName() } + + /** + * @group normal_funcs + * @deprecated As of 1.6.0, replaced by `isnan`. This will be removed in Spark 2.0. + */ + @deprecated("Use isnan. This will be removed in Spark 2.0.", "1.6.0") + def isNaN(e: Column): Column = isnan(e) /** * Return true iff the column is NaN. * * @group normal_funcs - * @since 1.5.0 + * @since 1.6.0 + */ + def isnan(e: Column): Column = withExpr { IsNaN(e.expr) } + + /** + * Return true iff the column is null. + * + * @group normal_funcs + * @since 1.6.0 */ - def isNaN(e: Column): Column = withExpr { IsNaN(e.expr) } + def isnull(e: Column): Column = withExpr { IsNull(e.expr) } /** * A column expression that generates monotonically increasing 64-bit integers. 
@@ -750,7 +794,24 @@ object functions extends LegacyFunctions { * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = withExpr { MonotonicallyIncreasingID() } + def monotonicallyIncreasingId(): Column = monotonically_increasing_id() + + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + * @since 1.6.0 + */ + def monotonically_increasing_id(): Column = withExpr { MonotonicallyIncreasingID() } /** * Returns col1 if it is not NaN, or col2 if col1 is NaN. @@ -825,15 +886,23 @@ object functions extends LegacyFunctions { */ def randn(): Column = randn(Utils.random.nextLong) + /** + * @group normal_funcs + * @since 1.4.0 + * @deprecated As of 1.6.0, replaced by `spark_partition_id`. This will be removed in Spark 2.0. + */ + @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") + def sparkPartitionId(): Column = spark_partition_id() + /** * Partition ID of the Spark task. * * Note that this is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs - * @since 1.4.0 + * @since 1.6.0 */ - def sparkPartitionId(): Column = withExpr { SparkPartitionID() } + def spark_partition_id(): Column = withExpr { SparkPartitionID() } /** * Computes the square root of the specified float value. @@ -2305,6 +2374,17 @@ object functions extends LegacyFunctions { */ def explode(e: Column): Column = withExpr { Explode(e.expr) } + /** + * Extracts json object from a json string based on json path specified, and returns json string + * of the extracted json object. It will return null if the input json string is invalid. + * + * @group collection_funcs + * @since 1.6.0 + */ + def get_json_object(e: Column, path: String): Column = withExpr { + GetJsonObject(e.expr, lit(path).expr) + } + /** * Creates a new row for a json column according to the given field names. * @@ -2313,7 +2393,7 @@ object functions extends LegacyFunctions { */ @scala.annotation.varargs def json_tuple(json: Column, fields: String*): Column = withExpr { - require(fields.length > 0, "at least 1 field name should be given.") + require(fields.nonEmpty, "at least 1 field name should be given.") JsonTuple(json.expr +: fields.map(Literal.apply)) } From 2169886883d33b33acf378ac42a626576b342df1 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 24 Nov 2015 23:13:01 -0800 Subject: [PATCH 888/896] [SPARK-11979][STREAMING] Empty TrackStateRDD cannot be checkpointed and recovered from checkpoint file This solves the following exception caused when empty state RDD is checkpointed and recovered. The root cause is that an empty OpenHashMapBasedStateMap cannot be deserialized as the initialCapacity is set to zero. 
``` Job aborted due to stage failure: Task 0 in stage 6.0 failed 1 times, most recent failure: Lost task 0.0 in stage 6.0 (TID 20, localhost): java.lang.IllegalArgumentException: requirement failed: Invalid initial capacity at scala.Predef$.require(Predef.scala:233) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.(StateMap.scala:96) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.(StateMap.scala:86) at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.readObject(StateMap.scala:291) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:606) at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1017) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1893) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350) at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:1990) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1915) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:370) at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:76) at org.apache.spark.serializer.DeserializationStream$$anon$1.getNext(Serializer.scala:181) at org.apache.spark.util.NextIterator.hasNext(NextIterator.scala:73) at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371) at scala.collection.Iterator$class.foreach(Iterator.scala:727) at scala.collection.AbstractIterator.foreach(Iterator.scala:1157) at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103) at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47) at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273) at scala.collection.AbstractIterator.to(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265) at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157) at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252) at scala.collection.AbstractIterator.toArray(Iterator.scala:1157) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921) at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858) at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66) at org.apache.spark.scheduler.Task.run(Task.scala:88) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:744) ``` Author: Tathagata Das Closes #9958 from tdas/SPARK-11979. 
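The shape of the bug, and of the fix, can be seen in a small standalone sketch: a map type that rejects a non-positive initial capacity (as OpenHashMap does) cannot be rebuilt from a size hint of zero unless the hint is clamped to a minimum, which is what the patch below does with math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY). The SizedMap class and the rebuild helpers here are invented for illustration only.

```
object CapacityHintSketch {
  val DEFAULT_INITIAL_CAPACITY = 64

  // Stand-in for OpenHashMap: a non-positive initial capacity is rejected,
  // which is what made deserializing an *empty* state map fail.
  class SizedMap[K, V](initialCapacity: Int) {
    require(initialCapacity >= 1, "Invalid initial capacity")
    private val underlying = scala.collection.mutable.HashMap.empty[K, V]
    def put(k: K, v: V): Unit = underlying.put(k, v)
    def size: Int = underlying.size
  }

  // Before the fix: the serialized size hint (0 for an empty map) was used directly.
  def rebuildBuggy(sizeHint: Int): SizedMap[Int, Int] =
    new SizedMap[Int, Int](sizeHint)

  // After the fix: clamp the hint to a sane minimum before allocating.
  def rebuildFixed(sizeHint: Int): SizedMap[Int, Int] =
    new SizedMap[Int, Int](math.max(sizeHint, DEFAULT_INITIAL_CAPACITY))

  def main(args: Array[String]): Unit = {
    assert(scala.util.Try(rebuildBuggy(0)).isFailure)  // reproduces the reported failure
    assert(rebuildFixed(0).size == 0)                  // an empty map now round-trips fine
  }
}
```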
--- .../spark/streaming/util/StateMap.scala | 19 +++++++----- .../spark/streaming/StateMapSuite.scala | 30 ++++++++++++------- .../streaming/rdd/TrackStateRDDSuite.scala | 10 +++++++ 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 34287c3e0090..3f139ad138c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -59,7 +59,7 @@ private[streaming] object StateMap { def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", DELTA_CHAIN_LENGTH_THRESHOLD) - new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold) + new OpenHashMapBasedStateMap[K, S](deltaChainThreshold) } } @@ -79,7 +79,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( @transient @volatile var parentStateMap: StateMap[K, S], - initialCapacity: Int = 64, + initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD ) extends StateMap[K, S] { self => @@ -89,12 +89,14 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( deltaChainThreshold = deltaChainThreshold) def this(deltaChainThreshold: Int) = this( - initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) + initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) - @transient @volatile private var deltaMap = - new OpenHashMap[K, StateInfo[S]](initialCapacity) + require(initialCapacity >= 1, "Invalid initial capacity") + require(deltaChainThreshold >= 1, "Invalid delta chain threshold") + + @transient @volatile private var deltaMap = new OpenHashMap[K, StateInfo[S]](initialCapacity) /** Get the session data if it exists */ override def get(key: K): Option[S] = { @@ -284,9 +286,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( // Read the data of the parent map. 
Keep reading records, until the limiter is reached // First read the approximate number of records to expect and allocate properly size // OpenHashMap - val parentSessionStoreSizeHint = inputStream.readInt() + val parentStateMapSizeHint = inputStream.readInt() + val newStateMapInitialCapacity = math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY) val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( - initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) + initialCapacity = newStateMapInitialCapacity, deltaChainThreshold) // Read the records until the limit marking object has been reached var parentSessionLoopDone = false @@ -338,4 +341,6 @@ private[streaming] object OpenHashMapBasedStateMap { class LimitMarker(val num: Int) extends Serializable val DELTA_CHAIN_LENGTH_THRESHOLD = 20 + + val DEFAULT_INITIAL_CAPACITY = 64 } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 48d3b41b66cb..c4a01eaea739 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -122,23 +122,27 @@ class StateMapSuite extends SparkFunSuite { test("OpenHashMapBasedStateMap - serializing and deserializing") { val map1 = new OpenHashMapBasedStateMap[Int, Int]() + testSerialization(map1, "error deserializing and serialized empty map") + map1.put(1, 100, 1) map1.put(2, 200, 2) + testSerialization(map1, "error deserializing and serialized map with data + no delta") val map2 = map1.copy() + // Do not test compaction + assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data") + map2.put(3, 300, 3) map2.put(4, 400, 4) + testSerialization(map2, "error deserializing and serialized map with 1 delta + new data") val map3 = map2.copy() + assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data") map3.put(3, 600, 3) map3.remove(2) - - // Do not test compaction - assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) - - val deser_map3 = Utils.deserialize[StateMap[Int, Int]]( - Utils.serialize(map3), Thread.currentThread().getContextClassLoader) - assertMap(deser_map3, map3, 1, "Deserialized map not same as original map") + testSerialization(map3, "error deserializing and serialized map with 2 delta + new data") } test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { @@ -156,11 +160,9 @@ class StateMapSuite extends SparkFunSuite { assert(map.deltaChainLength > deltaChainThreshold) assert(map.shouldCompact === true) - val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]]( - Utils.serialize(map), Thread.currentThread().getContextClassLoader) + val deser_map = testSerialization(map, "Deserialized + compacted map not same as original map") assert(deser_map.deltaChainLength < deltaChainThreshold) assert(deser_map.shouldCompact === false) - assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map") } test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") { @@ -265,6 +267,14 @@ class StateMapSuite extends SparkFunSuite { assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") } + private def 
testSerialization[MapType <: StateMap[Int, Int]]( + map: MapType, msg: String): MapType = { + val deserMap = Utils.deserialize[MapType]( + Utils.serialize(map), Thread.currentThread().getContextClassLoader) + assertMap(deserMap, map, 1, msg) + deserMap + } + // Assert whether all the data and operations on a state map matches that of a reference state map private def assertMap( mapToTest: StateMap[Int, Int], diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index 0feb3af1abb0..3b2d43f2ce58 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -332,6 +332,16 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) } + test("checkpointing empty state RDD") { + val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int]( + sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0)) + emptyStateRDD.checkpoint() + assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]]( + emptyStateRDD.getCheckpointFile.get) + assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + } + /** Assert whether the `trackStateByKey` operation generates expected results */ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T], From 2610e06124c7fc0b2b1cfb2e3050a35ab492fb71 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 25 Nov 2015 01:02:36 -0800 Subject: [PATCH 889/896] [SPARK-11970][SQL] Adding JoinType into JoinWith and support Sample in Dataset API Except inner join, maybe the other join types are also useful when users are using the joinWith function. Thus, added the joinType into the existing joinWith call in Dataset APIs. Also providing another joinWith interface for the cartesian-join-like functionality. Please provide your opinions. marmbrus rxin cloud-fan Thank you! Author: gatorsmile Closes #9921 from gatorsmile/joinWith. 
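Before the diff, a short usage sketch of the two API additions, adapted loosely from the new tests further below. It assumes an already constructed SQLContext (how it is created is not part of this patch); the ClassNullableData case class mirrors the one added to DatasetSuite, and the object and method names are invented for illustration.

```
import org.apache.spark.sql.SQLContext

case class ClassNullableData(a: String, b: Integer)

object JoinWithAndSampleSketch {
  def run(sqlContext: SQLContext): Unit = {
    import sqlContext.implicits._

    val ds1 = Seq(ClassNullableData("a", 1), ClassNullableData("c", 3)).toDS()
    val ds2 = Seq(("a", new Integer(1)), ("b", new Integer(2))).toDS()

    // New overload: joinWith now takes an explicit join type
    // ("inner", "outer", "left_outer", "right_outer", "leftsemi").
    val joined = ds1.joinWith(ds2, $"_1" === $"a", "outer")

    // New: sample a fraction of a Dataset, with or without replacement.
    val sampled = ds1.sample(withReplacement = false, fraction = 0.5, seed = 13L)

    joined.collect().foreach(println)
    println(s"sampled ${sampled.count()} of ${ds1.count()} records")
  }
}
```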
--- .../scala/org/apache/spark/sql/Dataset.scala | 45 +++++++++++++++---- .../org/apache/spark/sql/DatasetSuite.scala | 36 ++++++++++++--- 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index dd84b8bc11e2..97eb5b969280 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function._ - +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -83,7 +83,6 @@ class Dataset[T] private[sql]( /** * Returns the schema of the encoded form of the objects in this [[Dataset]]. - * * @since 1.6.0 */ def schema: StructType = resolvedTEncoder.schema @@ -185,7 +184,6 @@ class Dataset[T] private[sql]( * .transform(featurize) * .transform(...) * }}} - * * @since 1.6.0 */ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) @@ -453,6 +451,21 @@ class Dataset[T] private[sql]( c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] + /** + * Returns a new [[Dataset]] by sampling a fraction of records. + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] = + withPlan(Sample(0.0, fraction, withReplacement, seed, _)) + + /** + * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed. + * @since 1.6.0 + */ + def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + /* **************** * * Set operations * * **************** */ @@ -511,13 +524,17 @@ class Dataset[T] private[sql]( * types as well as working with relational data where either side of the join has column * names in common. * + * @param other Right side of the join. + * @param condition Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. * @since 1.6.0 */ - def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { val left = this.logicalPlan val right = other.logicalPlan - val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr))) + val joined = sqlContext.executePlan(Join(left, right, joinType = + JoinType(joinType), Some(condition.expr))) val leftOutput = joined.analyzed.output.take(left.output.length) val rightOutput = joined.analyzed.output.takeRight(right.output.length) @@ -540,6 +557,18 @@ class Dataset[T] private[sql]( } } + /** + * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair + * where `condition` evaluates to true. + * + * @param other Right side of the join. + * @param condition Join expression. 
+ * @since 1.6.0 + */ + def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + joinWith(other, condition, "inner") + } + /* ************************** * * Gather to Driver Actions * * ************************** */ @@ -584,7 +613,6 @@ class Dataset[T] private[sql]( * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. - * * @since 1.6.0 */ def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect() @@ -594,7 +622,6 @@ class Dataset[T] private[sql]( * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. - * * @since 1.6.0 */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c253fdbb8c99..7d539180ded9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -185,17 +185,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds2 = Seq(1, 2).toDS().as("b") checkAnswer( - ds1.joinWith(ds2, $"a.value" === $"b.value"), + ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"), (1, 1), (2, 2)) } - test("joinWith, expression condition") { - val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() - val ds2 = Seq(("a", 1), ("b", 2)).toDS() + test("joinWith, expression condition, outer join") { + val nullInteger = null.asInstanceOf[Integer] + val nullString = null.asInstanceOf[String] + val ds1 = Seq(ClassNullableData("a", 1), + ClassNullableData("c", 3)).toDS() + val ds2 = Seq(("a", new Integer(1)), + ("b", new Integer(2))).toDS() checkAnswer( - ds1.joinWith(ds2, $"_1" === $"a"), - (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2))) + ds1.joinWith(ds2, $"_1" === $"a", "outer"), + (ClassNullableData("a", 1), ("a", new Integer(1))), + (ClassNullableData("c", 3), (nullString, nullInteger)), + (ClassNullableData(nullString, nullInteger), ("b", new Integer(2)))) } test("joinWith tuple with primitive, expression") { @@ -225,7 +231,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"), ((("a", 1), ("a", 1)), ("a", 1)), ((("b", 2), ("b", 2)), ("b", 2))) - } test("groupBy function, keys") { @@ -367,6 +372,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1 -> "a", 2 -> "bc", 3 -> "d") } + test("sample with replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkAnswer( + data.sample(withReplacement = true, 0.05, seed = 13), + 5, 10, 52, 73) + } + + test("sample without replacement") { + val n = 100 + val data = sparkContext.parallelize(1 to n, 2).toDS() + checkAnswer( + data.sample(withReplacement = false, 0.05, seed = 13), + 3, 17, 27, 58, 62) + } + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -440,6 +461,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { case class ClassData(a: String, b: Int) +case class ClassNullableData(a: String, b: Integer) /** * A class used to test serialization using encoders. 
This class throws exceptions when using From a0f1a11837bfffb76582499d36fbaf21a1d628cb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 25 Nov 2015 01:03:18 -0800 Subject: [PATCH 890/896] [SPARK-11981][SQL] Move implementations of methods back to DataFrame from Queryable Also added show methods to Dataset. Author: Reynold Xin Closes #9964 from rxin/SPARK-11981. --- .../org/apache/spark/sql/DataFrame.scala | 35 ++++++++- .../scala/org/apache/spark/sql/Dataset.scala | 77 ++++++++++++++++++- .../spark/sql/execution/Queryable.scala | 32 ++------ 3 files changed, 111 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 5eca1db9525e..d8319b9a97fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -112,8 +112,8 @@ private[sql] object DataFrame { */ @Experimental class DataFrame private[sql]( - @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: QueryExecution) + @transient override val sqlContext: SQLContext, + @DeveloperApi @transient override val queryExecution: QueryExecution) extends Queryable with Serializable { // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure @@ -282,6 +282,35 @@ class DataFrame private[sql]( */ def schema: StructType = queryExecution.analyzed.schema + /** + * Prints the schema to the console in a nice tree format. + * @group basic + * @since 1.3.0 + */ + // scalastyle:off println + override def printSchema(): Unit = println(schema.treeString) + // scalastyle:on println + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @group basic + * @since 1.3.0 + */ + override def explain(extended: Boolean): Unit = { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + // scalastyle:off println + r => println(r.getString(0)) + // scalastyle:on println + } + } + + /** + * Prints the physical plan to the console for debugging purposes. + * @since 1.3.0 + */ + override def explain(): Unit = explain(extended = false) + /** * Returns all column names and their data types as an array. 
* @group basic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 97eb5b969280..da4600133290 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -61,8 +61,8 @@ import org.apache.spark.util.Utils */ @Experimental class Dataset[T] private[sql]( - @transient val sqlContext: SQLContext, - @transient val queryExecution: QueryExecution, + @transient override val sqlContext: SQLContext, + @transient override val queryExecution: QueryExecution, tEncoder: Encoder[T]) extends Queryable with Serializable { /** @@ -85,7 +85,25 @@ class Dataset[T] private[sql]( * Returns the schema of the encoded form of the objects in this [[Dataset]]. * @since 1.6.0 */ - def schema: StructType = resolvedTEncoder.schema + override def schema: StructType = resolvedTEncoder.schema + + /** + * Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format. + * @since 1.6.0 + */ + override def printSchema(): Unit = toDF().printSchema() + + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * @since 1.6.0 + */ + override def explain(extended: Boolean): Unit = toDF().explain(extended) + + /** + * Prints the physical plan to the console for debugging purposes. + * @since 1.6.0 + */ + override def explain(): Unit = toDF().explain() /* ************* * * Conversions * @@ -152,6 +170,59 @@ class Dataset[T] private[sql]( */ def count(): Long = toDF().count() + /** + * Displays the content of this [[Dataset]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * + * @since 1.6.0 + */ + def show(numRows: Int): Unit = show(numRows, truncate = true) + + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. + * + * @since 1.6.0 + */ + def show(): Unit = show(20) + + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @since 1.6.0 + */ + def show(truncate: Boolean): Unit = show(20, truncate) + + /** + * Displays the [[DataFrame]] in a tabular form. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @since 1.6.0 + */ + def show(numRows: Int, truncate: Boolean): Unit = toDF().show(numRows, truncate) + /** * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. 
* @since 1.6.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index 321e2c783537..f2f5997d1b7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution +import scala.util.control.NonFatal + import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType -import scala.util.control.NonFatal - /** A trait that holds shared code between DataFrames and Datasets. */ private[sql] trait Queryable { def schema: StructType @@ -37,31 +37,9 @@ private[sql] trait Queryable { } } - /** - * Prints the schema to the console in a nice tree format. - * @group basic - * @since 1.3.0 - */ - // scalastyle:off println - def printSchema(): Unit = println(schema.treeString) - // scalastyle:on println + def printSchema(): Unit - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * @since 1.3.0 - */ - def explain(extended: Boolean): Unit = { - val explain = ExplainCommand(queryExecution.logical, extended = extended) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { - // scalastyle:off println - r => println(r.getString(0)) - // scalastyle:on println - } - } + def explain(extended: Boolean): Unit - /** - * Only prints the physical plan to the console for debugging purposes. - * @since 1.3.0 - */ - def explain(): Unit = explain(extended = false) + def explain(): Unit } From 63850026576b3ea7783f9d4b975171dc3cff6e4c Mon Sep 17 00:00:00 2001 From: Ashwin Swaroop Date: Wed, 25 Nov 2015 13:41:14 +0000 Subject: [PATCH 891/896] [SPARK-11686][CORE] Issue WARN when dynamic allocation is disabled due to spark.dynamicAllocation.enabled and spark.executor.instances both set Changed the log type to a 'warning' instead of 'info' as required. Author: Ashwin Swaroop Closes #9926 from ashwinswaroop/master. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e19ba113702c..2c10779f2b89 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -556,7 +556,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Optionally scale number of executors dynamically based on workload. Exposed for testing. val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf) if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) { - logInfo("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") + logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") } _executorAllocationManager = From b9b6fbe89b6d1a890faa02c1a53bb670a6255362 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 25 Nov 2015 13:49:58 +0000 Subject: [PATCH 892/896] =?UTF-8?q?[SPARK-11860][PYSAPRK][DOCUMENTATION]?= =?UTF-8?q?=20Invalid=20argument=20specification=20=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …for registerFunction [Python] Straightforward change on the python doc Author: Jeff Zhang Closes #9901 from zjffdu/SPARK-11860. 
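For context on the `registerFunction` doc fix below, the same idea can be sketched on the Scala side of the API. This is a hypothetical illustration only: the UDF name, function body, and query are made up, and it assumes a Spark 1.6-era `SQLContext` built from an existing `SparkContext` named `sc`.

```scala
// Hypothetical sketch: register a plain Scala function as a SQL UDF and call it
// from a SQL statement, mirroring what pyspark's registerFunction documents.
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
sqlContext.udf.register("stringLength", (s: String) => s.length)
sqlContext.sql("SELECT stringLength('spark')").show()
```

The Python `registerFunction(name, f, returnType)` documented in the patch below mirrors this shape, with the return type passed explicitly because Python functions carry no static type information.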
--- python/pyspark/sql/context.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 5a85ac31025e..a49c1b58d018 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -195,14 +195,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a lambda function as a UDF so it can be used in SQL statements. + """Registers a python function (including lambda function) as a UDF + so it can be used in SQL statements. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param samplingRatio: lambda function + :param f: python function :param returnType: a :class:`DataType` object >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) From 0a5aef753e70e93d7e56054f354a52e4d4e18932 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Wed, 25 Nov 2015 09:34:34 -0600 Subject: [PATCH 893/896] [SPARK-10666][SPARK-6880][CORE] Use properties from ActiveJob associated with a Stage This issue was addressed in https://github.com/apache/spark/pull/5494, but the fix in that PR, while safe in the sense that it will prevent the SparkContext from shutting down, misses the actual bug. The intent of `submitMissingTasks` should be understood as "submit the Tasks that are missing for the Stage, and run them as part of the ActiveJob identified by jobId". Because of a long-standing bug, the `jobId` parameter was never being used. Instead, we were trying to use the jobId with which the Stage was created -- which may no longer exist as an ActiveJob, hence the crash reported in SPARK-6880. The correct fix is to use the ActiveJob specified by the supplied jobId parameter, which is guaranteed to exist at the call sites of submitMissingTasks. This fix should be applied to all maintenance branches, since it has existed since 1.0. kayousterhout pankajarora12 Author: Mark Hamstra Author: Imran Rashid Closes #6291 from markhamstra/SPARK-6880. --- .../apache/spark/scheduler/DAGScheduler.scala | 6 +- .../spark/scheduler/DAGSchedulerSuite.scala | 107 +++++++++++++++++- 2 files changed, 109 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 77a184dfe4be..e01a9609b9a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -946,7 +946,9 @@ class DAGScheduler( stage.resetInternalAccumulators() } - val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull + // Use the scheduling pool, job group, description, etc. 
from an ActiveJob associated + // with this Stage + val properties = jobIdToActiveJob(jobId).properties runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -1047,7 +1049,7 @@ class DAGScheduler( stage.pendingPartitions ++= tasks.map(_.partitionId) logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 4d6b25455226..653d41fc053c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.util.Properties + import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal @@ -262,9 +264,10 @@ class DAGSchedulerSuite rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, - listener: JobListener = jobListener): Int = { + listener: JobListener = jobListener, + properties: Properties = null): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener)) + runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener, properties)) jobId } @@ -1322,6 +1325,106 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + def checkJobPropertiesAndPriority(taskSet: TaskSet, expected: String, priority: Int): Unit = { + assert(taskSet.properties != null) + assert(taskSet.properties.getProperty("testProperty") === expected) + assert(taskSet.priority === priority) + } + + def launchJobsThatShareStageAndCancelFirst(): ShuffleDependency[Int, Int, Nothing] = { + val baseRdd = new MyRDD(sc, 1, Nil) + val shuffleDep1 = new ShuffleDependency(baseRdd, new HashPartitioner(1)) + val intermediateRdd = new MyRDD(sc, 1, List(shuffleDep1)) + val shuffleDep2 = new ShuffleDependency(intermediateRdd, new HashPartitioner(1)) + val finalRdd1 = new MyRDD(sc, 1, List(shuffleDep2)) + val finalRdd2 = new MyRDD(sc, 1, List(shuffleDep2)) + val job1Properties = new Properties() + val job2Properties = new Properties() + job1Properties.setProperty("testProperty", "job1") + job2Properties.setProperty("testProperty", "job2") + + // Run jobs 1 & 2, both referencing the same stage, then cancel job1. + // Note that we have to submit job2 before we cancel job1 to have them actually share + // *Stages*, and not just shuffle dependencies, due to skipped stages (at least until + // we address SPARK-10193.) 
+ val jobId1 = submit(finalRdd1, Array(0), properties = job1Properties) + val jobId2 = submit(finalRdd2, Array(0), properties = job2Properties) + assert(scheduler.activeJobs.nonEmpty) + val testProperty1 = scheduler.jobIdToActiveJob(jobId1).properties.getProperty("testProperty") + + // remove job1 as an ActiveJob + cancel(jobId1) + + // job2 should still be running + assert(scheduler.activeJobs.nonEmpty) + val testProperty2 = scheduler.jobIdToActiveJob(jobId2).properties.getProperty("testProperty") + assert(testProperty1 != testProperty2) + // NB: This next assert isn't necessarily the "desired" behavior; it's just to document + // the current behavior. We've already submitted the TaskSet for stage 0 based on job1, but + // even though we have cancelled that job and are now running it because of job2, we haven't + // updated the TaskSet's properties. Changing the properties to "job2" is likely the more + // correct behavior. + val job1Id = 0 // TaskSet priority for Stages run with "job1" as the ActiveJob + checkJobPropertiesAndPriority(taskSets(0), "job1", job1Id) + complete(taskSets(0), Seq((Success, makeMapStatus("hostA", 1)))) + + shuffleDep1 + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active + */ + test("stage used by two jobs, the first no longer active (SPARK-6880)") { + launchJobsThatShareStageAndCancelFirst() + + // The next check is the key for SPARK-6880. For the stage which was shared by both job1 and + // job2 but never had any tasks submitted for job1, the properties of job2 are now used to run + // the stage. + checkJobPropertiesAndPriority(taskSets(1), "job2", 1) + + complete(taskSets(1), Seq((Success, makeMapStatus("hostA", 1)))) + assert(taskSets(2).properties != null) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + + /** + * Makes sure that tasks for a stage used by multiple jobs are submitted with the properties of a + * later, active job if they were previously run under a job that is no longer active, even when + * there are fetch failures + */ + test("stage used by two jobs, some fetch failures, and the first job no longer active " + + "(SPARK-6880)") { + val shuffleDep1 = launchJobsThatShareStageAndCancelFirst() + val job2Id = 1 // TaskSet priority for Stages run with "job2" as the ActiveJob + + // lets say there is a fetch failure in this task set, which makes us go back and + // run stage 0, attempt 1 + complete(taskSets(1), Seq( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + + // stage 0, attempt 1 should have the properties of job2 + assert(taskSets(2).stageId === 0) + assert(taskSets(2).stageAttemptId === 1) + checkJobPropertiesAndPriority(taskSets(2), "job2", job2Id) + + // run the rest of the stages normally, checking that they have the correct properties + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(3), "job2", job2Id) + complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 1)))) + checkJobPropertiesAndPriority(taskSets(4), "job2", job2Id) + complete(taskSets(4), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assert(scheduler.activeJobs.isEmpty) + + assertDataStructuresEmpty() + } + test("run trivial shuffle with out-of-band failure and retry") { val 
shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) From c1f85fc71e71e07534b89c84677d977bb20994f8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Nov 2015 09:47:20 -0800 Subject: [PATCH 894/896] [SPARK-11956][CORE] Fix a few bugs in network lib-based file transfer. - NettyRpcEnv::openStream() now correctly propagates errors to the read side of the pipe. - NettyStreamManager now throws if the file being transferred does not exist. - The network library now correctly handles zero-sized streams. Author: Marcelo Vanzin Closes #9941 from vanzin/SPARK-11956. --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 19 +++++++++---- .../spark/rpc/netty/NettyStreamManager.scala | 2 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 27 +++++++++++++----- .../client/TransportResponseHandler.java | 28 ++++++++++++------- .../org/apache/spark/network/StreamSuite.java | 23 ++++++++++++++- 5 files changed, 75 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 68701f609f77..c8fa870f50e6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -27,7 +27,7 @@ import javax.annotation.Nullable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag -import scala.util.{DynamicVariable, Failure, Success} +import scala.util.{DynamicVariable, Failure, Success, Try} import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -368,13 +368,22 @@ private[netty] class NettyRpcEnv( @volatile private var error: Throwable = _ - def setError(e: Throwable): Unit = error = e + def setError(e: Throwable): Unit = { + error = e + source.close() + } override def read(dst: ByteBuffer): Int = { - if (error != null) { - throw error + val result = if (error == null) { + Try(source.read(dst)) + } else { + Failure(error) + } + + result match { + case Success(bytesRead) => bytesRead + case Failure(error) => throw error } - source.read(dst) } override def close(): Unit = source.close() diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index eb1d2604fb23..a2768b4252dc 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -44,7 +44,7 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") } - require(file != null, s"File not found: $streamId") + require(file != null && file.isFile(), s"File not found: $streamId") new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 2b664c6313ef..6cc958a5f6bc 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -729,23 +729,36 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val tempDir = Utils.createTempDir() val file = new File(tempDir, "file") Files.write(UUID.randomUUID().toString(), file, UTF_8) + val empty = new File(tempDir, "empty") + Files.write("", empty, UTF_8); val jar = new 
File(tempDir, "jar") Files.write(UUID.randomUUID().toString(), jar, UTF_8) val fileUri = env.fileServer.addFile(file) + val emptyUri = env.fileServer.addFile(empty) val jarUri = env.fileServer.addJar(jar) val destDir = Utils.createTempDir() - val destFile = new File(destDir, file.getName()) - val destJar = new File(destDir, jar.getName()) - val sm = new SecurityManager(conf) val hc = SparkHadoopUtil.get.conf - Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false) - Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false) - assert(Files.equal(file, destFile)) - assert(Files.equal(jar, destJar)) + val files = Seq( + (file, fileUri), + (empty, emptyUri), + (jar, jarUri)) + files.foreach { case (f, uri) => + val destFile = new File(destDir, f.getName()) + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + assert(Files.equal(f, destFile)) + } + + // Try to download files that do not exist. + Seq("files", "jars").foreach { root => + intercept[Exception] { + val uri = env.address.toSparkURL + s"/$root/doesNotExist" + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + } + } } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index be181e066082..4c15045363b8 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -185,16 +185,24 @@ public void handle(ResponseMessage message) { StreamResponse resp = (StreamResponse) message; StreamCallback callback = streamCallbacks.poll(); if (callback != null) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback); - try { - TransportFrameDecoder frameDecoder = (TransportFrameDecoder) - channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); - frameDecoder.setInterceptor(interceptor); - streamActive = true; - } catch (Exception e) { - logger.error("Error installing stream handler.", e); - deactivateStream(); + if (resp.byteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, + callback); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + callback.onComplete(resp.streamId); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } } } else { logger.error("Could not find callback for StreamResponse."); diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 00158fd08162..538f3efe8d6f 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -51,13 +51,14 @@ import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" }; + private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; private static TransportServer server; private static TransportClientFactory clientFactory; private static File testFile; 
private static File tempDir; + private static ByteBuffer emptyBuffer; private static ByteBuffer smallBuffer; private static ByteBuffer largeBuffer; @@ -73,6 +74,7 @@ private static ByteBuffer createBuffer(int bufSize) { @BeforeClass public static void setUp() throws Exception { tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); smallBuffer = createBuffer(100); largeBuffer = createBuffer(100000); @@ -103,6 +105,8 @@ public ManagedBuffer openStream(String streamId) { return new NioManagedBuffer(largeBuffer); case "smallBuffer": return new NioManagedBuffer(smallBuffer); + case "emptyBuffer": + return new NioManagedBuffer(emptyBuffer); case "file": return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); default: @@ -138,6 +142,18 @@ public static void tearDown() { } } + @Test + public void testZeroLengthStream() throws Throwable { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5)); + task.run(); + task.check(); + } finally { + client.close(); + } + } + @Test public void testSingleStream() throws Throwable { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); @@ -226,6 +242,11 @@ public void run() { outFile = File.createTempFile("data", ".tmp", tempDir); out = new FileOutputStream(outFile); break; + case "emptyBuffer": + baos = new ByteArrayOutputStream(); + out = baos; + srcBuffer = emptyBuffer; + break; default: throw new IllegalArgumentException(streamId); } From faabdfa2bd416ae514961535f1953e8e9e8b1f3f Mon Sep 17 00:00:00 2001 From: felixcheung Date: Wed, 25 Nov 2015 10:36:35 -0800 Subject: [PATCH 895/896] [SPARK-11984][SQL][PYTHON] Fix typos in doc for pivot for scala and python Author: felixcheung Closes #9967 from felixcheung/pypivotdoc. --- python/pyspark/sql/group.py | 6 +++--- .../src/main/scala/org/apache/spark/sql/GroupedData.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index d8ed7eb2dda6..1911588309af 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -169,11 +169,11 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): - """Pivots a column of the current DataFrame and preform the specified aggregation. + """Pivots a column of the current DataFrame and perform the specified aggregation. :param pivot_col: Column to pivot - :param values: Optional list of values of pivotColumn that will be translated to columns in - the output data frame. If values are not provided the method with do an immediate call + :param values: Optional list of values of pivot column that will be translated to columns in + the output DataFrame. If values are not provided the method will do an immediate call to .distinct() on the pivot column. >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index abd531c4ba54..13341a88a6b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -282,7 +282,7 @@ class GroupedData protected[sql]( } /** - * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. 
+   * Pivots a column of the current [[DataFrame]] and perform the specified aggregation.
    * There are two versions of pivot function: one that requires the caller to specify the list
    * of distinct values to pivot on, and one that does not. The latter is more concise but less
    * efficient, because Spark needs to first compute the list of distinct values internally.
@@ -321,7 +321,7 @@ class GroupedData protected[sql](
   }
 
   /**
-   * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
+   * Pivots a column of the current [[DataFrame]] and perform the specified aggregation.
    * There are two versions of pivot function: one that requires the caller to specify the list
    * of distinct values to pivot on, and one that does not. The latter is more concise but less
    * efficient, because Spark needs to first compute the list of distinct values internally.
@@ -353,7 +353,7 @@ class GroupedData protected[sql](
   }
 
   /**
-   * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
+   * Pivots a column of the current [[DataFrame]] and perform the specified aggregation.
    * There are two versions of pivot function: one that requires the caller to specify the list
    * of distinct values to pivot on, and one that does not. The latter is more concise but less
    * efficient, because Spark needs to first compute the list of distinct values internally.

From 6b781576a15d8d5c5fbed8bef1c5bda95b3d44ac Mon Sep 17 00:00:00 2001
From: Zhongshuai Pei
Date: Wed, 25 Nov 2015 10:37:34 -0800
Subject: [PATCH 896/896] [SPARK-11974][CORE] Not all the temp dirs had been deleted when the JVM exits

Deleting the temp dirs like this:

```
scala> import scala.collection.mutable
import scala.collection.mutable

scala> val a = mutable.Set(1,2,3,4,7,0,8,98,9)
a: scala.collection.mutable.Set[Int] = Set(0, 9, 1, 2, 3, 7, 4, 8, 98)

scala> a.foreach(x => {a.remove(x) })

scala> a.foreach(println(_))
98
```

You may not modify a collection while traversing or iterating over it, so this cannot delete
all elements of the collection.

Author: Zhongshuai Pei

Closes #9951 from DoingDone9/Bug_RemainDir.
---
 .../scala/org/apache/spark/util/ShutdownHookManager.scala | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
index db4a8b304ec3..4012dca3ecdf 100644
--- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
+++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
@@ -57,7 +57,9 @@ private[spark] object ShutdownHookManager extends Logging {
   // Add a shutdown hook to delete the temp dirs when the JVM exits
   addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () =>
     logInfo("Shutdown hook called")
-    shutdownDeletePaths.foreach { dirPath =>
+    // we need to materialize the paths to delete because deleteRecursively removes items from
+    // shutdownDeletePaths as we are traversing through it.
+    shutdownDeletePaths.toArray.foreach { dirPath =>
       try {
         logInfo("Deleting directory " + dirPath)
         Utils.deleteRecursively(new File(dirPath))
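The fix above works by snapshotting the mutable set before iterating, since removing elements from a Scala `mutable.Set` while traversing it can silently skip elements. A minimal sketch of the pattern, with illustrative values rather than real temp-dir paths:

```scala
import scala.collection.mutable

// Mutating a mutable.Set while iterating over it may skip elements,
// which is the bug described in SPARK-11974.
val live = mutable.Set("dir1", "dir2", "dir3", "dir4")
live.foreach(live.remove)        // some entries can survive this loop

// Snapshot the contents first, then mutate the original set safely.
val all = mutable.Set("dir1", "dir2", "dir3", "dir4")
all.toArray.foreach(all.remove)  // every entry is removed
assert(all.isEmpty)
```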