Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix
  • Loading branch information
zhengruifeng committed Sep 4, 2017
commit 971e52c4a18b4261d82ac14fefa6bb849367562c
Original file line number Diff line number Diff line change
Expand Up @@ -444,13 +444,13 @@ class LogisticRegressionWithLBFGS
lr.setFitIntercept(addIntercept)
lr.setMaxIter(optimizer.getNumIterations())
lr.setTol(optimizer.getConvergenceTol())
// Determine if we should cache the DF
lr.setHandlePersistence(input.getStorageLevel == StorageLevel.NONE)
// Convert our input into a DataFrame
val spark = SparkSession.builder().sparkContext(input.context).getOrCreate()
val df = spark.createDataFrame(input.map(_.asML))
// Determine if we should cache the DF
val handlePersistence = input.getStorageLevel == StorageLevel.NONE
// Train our model
val mlLogisticRegressionModel = lr.train(df, handlePersistence)
val mlLogisticRegressionModel = lr.train(df)
// convert the model
val weights = Vectors.dense(mlLogisticRegressionModel.coefficients.toArray)
createModel(weights, mlLogisticRegressionModel.intercept)
Expand Down