Skip to content

Commit 49600cc

Browse files
committed
address comments
1 parent 429ff7d commit 49600cc

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,7 @@ class IsotonicRegressionModel (
133133
}
134134

135135
override def save(sc: SparkContext, path: String): Unit = {
136-
val intervals = boundaries.toList.zip(predictions.toList).toArray
137-
val data = IsotonicRegressionModel.SaveLoadV1_0.Data(intervals)
138-
IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, data, isotonic)
136+
IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic)
139137
}
140138

141139
override protected def formatVersion: String = "1.0"
@@ -153,9 +151,14 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
153151
def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"
154152

155153
/** Model data for model import/export */
156-
case class Data(intervals: Array[(Double, Double)])
157-
158-
def save(sc: SparkContext, path: String, data: Data, isotonic: Boolean): Unit = {
154+
case class Data(boundary: Double, prediction: Double)
155+
156+
def save(
157+
sc: SparkContext,
158+
path: String,
159+
boundaries: Array[Double],
160+
predictions: Array[Double],
161+
isotonic: Boolean): Unit = {
159162
val sqlContext = new SQLContext(sc)
160163
import sqlContext.implicits._
161164

@@ -164,21 +167,18 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {
164167
("isotonic" -> isotonic)))
165168
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
166169

167-
val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
168-
dataRDD.saveAsParquetFile(dataPath(path))
170+
sqlContext.createDataFrame(boundaries.toList.zip(predictions.toList)
171+
.map { case (b, p) => Data(b, p) }).saveAsParquetFile(dataPath(path))
169172
}
170173

171174
def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
172175
val sqlContext = new SQLContext(sc)
173176
val dataRDD = sqlContext.parquetFile(dataPath(path))
174177

175178
checkSchema[Data](dataRDD.schema)
176-
val dataArray = dataRDD.select("intervals").take(1)
177-
assert(dataArray.size == 1,
178-
s"Unable to load IsotonicRegressionModel data from: ${dataPath(path)}")
179-
val data = dataArray(0)
180-
val intervals = data.getAs[Seq[(Double, Double)]](0)
181-
val (boundaries, predictions) = intervals.unzip
179+
val dataArray = dataRDD.select("boundary", "prediction").collect()
180+
val (boundaries, predictions) = dataArray.map {
181+
x => (x.getAs[Double](0), x.getAs[Double](1)) }.toList.unzip
182182
(boundaries.toArray, predictions.toArray)
183183
}
184184
}

0 commit comments

Comments
 (0)