@@ -32,7 +32,39 @@ import org.apache.spark.storage.StorageLevel
3232 */
private[classification] trait LogisticRegressionParams extends Params
  with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol
  with HasScoreCol with HasPredictionCol {

  /**
   * Validates and transforms the input schema with the provided param map.
   *
   * Checks that the features column is a vector column and, when fitting, that the label
   * column is a double column; then appends the score and prediction columns.
   *
   * @param schema input schema
   * @param paramMap additional parameters
   * @param fitting whether this is in fitting (the label column is only required when fitting)
   * @return output schema with score and prediction columns appended
   */
  protected def transformSchema(
      schema: StructType,
      paramMap: ParamMap,
      fitting: Boolean): StructType = {
    // Embedded (default) params are overridden by caller-supplied ones.
    val map = this.paramMap ++ paramMap
    val featuresType = schema(map(featuresCol)).dataType
    // TODO: Support casting Array[Double] and Array[Float] to Vector.
    require(featuresType.isInstanceOf[VectorUDT],
      s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
    if (fitting) {
      // The label column is only consumed during training, so skip this check at predict time.
      val labelType = schema(map(labelCol)).dataType
      require(labelType == DoubleType,
        s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
    }
    // Refuse to silently overwrite existing output columns.
    val fieldNames = schema.fieldNames
    require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
    require(!fieldNames.contains(map(predictionCol)),
      s"Prediction column ${map(predictionCol)} already exists.")
    // Output rows carry every input field plus non-nullable score and prediction columns.
    val outputFields = schema.fields ++ Seq(
      StructField(map(scoreCol), DoubleType, nullable = false),
      StructField(map(predictionCol), DoubleType, nullable = false))
    StructType(outputFields)
  }
}
3668
3769/**
3870 * Logistic regression.
@@ -71,22 +103,7 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
71103 }
72104
73105 override def transform (schema : StructType , paramMap : ParamMap ): StructType = {
74- val map = this .paramMap ++ paramMap
75- val featuresType = schema(map(featuresCol)).dataType
76- // TODO: Support casting Array[Double] and Array[Float] to Vector.
77- require(featuresType.isInstanceOf [VectorUDT ],
78- s " Features column ${map(featuresCol)} must be a vector column but got $featuresType. " )
79- val labelType = schema(map(labelCol)).dataType
80- require(labelType == DoubleType ,
81- s " Cannot convert label column ${map(labelCol)} of type $labelType to a double column. " )
82- val fieldNames = schema.fieldNames
83- require(! fieldNames.contains(map(scoreCol)), s " Score column ${map(scoreCol)} already exists. " )
84- require(! fieldNames.contains(map(predictionCol)),
85- s " Prediction column ${map(predictionCol)} already exists. " )
86- val outputFields = schema.fields ++ Seq (
87- StructField (map(scoreCol), DoubleType , false ),
88- StructField (map(predictionCol), DoubleType , false ))
89- StructType (outputFields)
106+ transformSchema(schema, paramMap, fitting = true )
90107 }
91108}
92109
@@ -104,19 +121,7 @@ class LogisticRegressionModel private[ml] (
104121 def setPredictionCol (value : String ): this .type = { set(predictionCol, value); this }
105122
106123 override def transform (schema : StructType , paramMap : ParamMap ): StructType = {
107- val map = this .paramMap ++ paramMap
108- val featuresType = schema(map(featuresCol)).dataType
109- // TODO: Support casting Array[Double] and Array[Float] to Vector.
110- require(featuresType.isInstanceOf [VectorUDT ],
111- s " Features column ${map(featuresCol)} must be a vector column but got $featuresType. " )
112- val fieldNames = schema.fieldNames
113- require(! fieldNames.contains(map(scoreCol)), s " Score column ${map(scoreCol)} already exists. " )
114- require(! fieldNames.contains(map(predictionCol)),
115- s " Prediction column ${map(predictionCol)} already exists. " )
116- val outputFields = schema.fields ++ Seq (
117- StructField (map(scoreCol), DoubleType , false ),
118- StructField (map(predictionCol), DoubleType , false ))
119- StructType (outputFields)
124+ transformSchema(schema, paramMap, fitting = false )
120125 }
121126
122127 override def transform (dataset : SchemaRDD , paramMap : ParamMap ): SchemaRDD = {
0 commit comments