Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
b69f201
Added tunable parallelism to the pyspark implementation of one vs. re…
ajaysaini725 Jun 12, 2017
e750d3e
Fixed python style.
ajaysaini725 Jun 12, 2017
81d458b
Added functionality for tuning parallelism in the Scala implementatio…
ajaysaini725 Jun 13, 2017
2133378
Fixed code according to comments. Added both annotations and unit tes…
ajaysaini725 Jun 13, 2017
c59b1d8
Modified parallel one vs rest to use futures.
ajaysaini725 Jun 22, 2017
5f635a2
Put the parallelism parameter as well as the function for getting an …
ajaysaini725 Jun 23, 2017
4431ffc
Responded to pull request comments.
ajaysaini725 Jun 23, 2017
a841b3e
Made changes based on pull request comments.
ajaysaini725 Jul 6, 2017
a95a8af
Fixed based on pull request comments
ajaysaini725 Jul 14, 2017
d45bc23
Fixed based on comments
ajaysaini725 Jul 18, 2017
30ac62d
Reverting merge and adding change that would fix merge conflict (maki…
ajaysaini725 Jul 19, 2017
cc634d2
Merge branch 'master' into spark-21027
ajaysaini725 Jul 19, 2017
ce14172
Style fix with docstring
ajaysaini725 Jul 20, 2017
1c9de16
Fixed based on comments.
ajaysaini725 Jul 27, 2017
9f34404
Fixed style issue.
ajaysaini725 Jul 27, 2017
585a3f8
Fixed merge conflict
ajaysaini725 Aug 12, 2017
f65381a
Fixed remaining part of merge conflict.
ajaysaini725 Aug 23, 2017
2a335fe
Fixed style problem
ajaysaini725 Aug 23, 2017
049f371
Merge branch 'master' into spark-21027
WeichenXu123 Sep 2, 2017
ddc2ff4
address review feedback issues
WeichenXu123 Sep 3, 2017
fc6fd5e
update migration guide
WeichenXu123 Sep 3, 2017
7d0849e
update desc
WeichenXu123 Sep 6, 2017
edcf85c
fix style
WeichenXu123 Sep 6, 2017
7a1d404
merge master & resolve conflicts
WeichenXu123 Sep 6, 2017
c24d4e2
update out-of-date shared.py
WeichenXu123 Sep 12, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Modified parallel one vs rest to use futures.
  • Loading branch information
ajaysaini725 committed Jun 22, 2017
commit c59b1d897c24d88753f478243ac8428598108da3
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,19 @@

package org.apache.spark.ml.classification

import java.util.{List => JList}
import java.util.UUID
import java.util.concurrent.ExecutorService

import scala.collection.JavaConverters._
import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.forkjoin.ForkJoinPool
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.language.existentials

import com.google.common.util.concurrent.MoreExecutors
import org.apache.hadoop.fs.Path
import org.json4s.{DefaultFormats, JObject, _}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
Expand All @@ -40,7 +39,9 @@ import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.SparkContext
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.ThreadUtils

private[ml] trait ClassifierTypeTrait {
// scalastyle:off structural.type
Expand Down Expand Up @@ -286,7 +287,7 @@ final class OneVsRest @Since("1.4.0") (
"the number of processes to use when running parallel one vs. rest", ParamValidators.gtEq(1))

setDefault(
parallelism -> 4
parallelism -> 1
)

@Since("1.4.0")
Expand Down Expand Up @@ -324,6 +325,14 @@ final class OneVsRest @Since("1.4.0") (
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
}

/**
 * Selects the executor service to use for the current parallelism setting.
 *
 * When `getParallelism` equals 1, Guava's same-thread executor is returned and no
 * thread pool is created. Otherwise a new daemon cached thread pool is requested
 * from `ThreadUtils`, named after this class and given `getParallelism` as its
 * size argument (presumably the maximum thread count -- confirm against ThreadUtils).
 *
 * This method never shuts the returned executor down.
 * NOTE(review): confirm that callers release the pool when they are done with it.
 *
 * @return the executor service chosen for the current value of `getParallelism`
 */
def getExecutorService: ExecutorService = {
  // Expression-oriented if/else: the selected branch is the method's result,
  // so no explicit `return` is needed.
  if (getParallelism == 1) {
    MoreExecutors.sameThreadExecutor()
  } else {
    ThreadUtils
      .newDaemonCachedThreadPool(s"${this.getClass.getSimpleName}-thread-pool", getParallelism)
  }
}

@Since("2.0.0")
override def fit(dataset: Dataset[_]): OneVsRestModel = {
transformSchema(dataset.schema)
Expand All @@ -350,25 +359,28 @@ final class OneVsRest @Since("1.4.0") (
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
}

val iters = Range(0, numClasses).par
iters.tasksupport = new ForkJoinTaskSupport(
new ForkJoinPool(Math.min(getParallelism, numClasses))
)
val executor = getExecutorService
val executionContext = ExecutionContext.fromExecutorService(executor)

// create k columns, one for each binary classifier.
val models = iters.map { index =>
// generate new label metadata for the binary problem.
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
val labelColName = "mc2b$" + index
val trainingDataset = multiclassLabeled.withColumn(
labelColName, when(col($(labelCol)) === index.toDouble, 1.0).otherwise(0.0), newLabelMeta)
val classifier = getClassifier
val paramMap = new ParamMap()
paramMap.put(classifier.labelCol -> labelColName)
paramMap.put(classifier.featuresCol -> getFeaturesCol)
paramMap.put(classifier.predictionCol -> getPredictionCol)
classifier.fit(trainingDataset, paramMap)
}.toArray[ClassificationModel[_, _]]
val modelFutures = Range(0, numClasses).map { index =>
Future[ClassificationModel[_, _]] {
// generate new label metadata for the binary problem.
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
val labelColName = "mc2b$" + index
val trainingDataset = multiclassLabeled.withColumn(
labelColName, when(col($(labelCol)) === index.toDouble, 1.0).otherwise(0.0), newLabelMeta)
val classifier = getClassifier
val paramMap = new ParamMap()
paramMap.put(classifier.labelCol -> labelColName)
paramMap.put(classifier.featuresCol -> getFeaturesCol)
paramMap.put(classifier.predictionCol -> getPredictionCol)
classifier.fit(trainingDataset, paramMap)
}(executionContext)
}
val models = modelFutures
.map(ThreadUtils.awaitResult(_, Duration.Inf)).toArray[ClassificationModel[_, _]]

instr.logNumFeatures(models.head.numFeatures)

if (handlePersistence) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.util

import org.apache.spark.ml.param.Params

/**
 * Mixin for estimators that are trained in a multithreaded environment.
 *
 * Extends [[Params]] and currently declares no members of its own.
 */
private[ml] trait ParallelismParam extends Params
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,14 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau

test("one-vs-rest: tuning parallelism does not change output") {
val numClasses = 3
val ovaPar2 = new OneVsRest()
val ovaPar1 = new OneVsRest()
.setClassifier(new LogisticRegression)
.setParallelism(2)

val ovaModelPar2 = ovaPar2.fit(dataset)
val ovaModelPar1 = ovaPar1.fit(dataset)

val transformedDatasetPar2 = ovaModelPar2.transform(dataset)
val transformedDatasetPar1 = ovaModelPar1.transform(dataset)

val ovaResultsPar2 = transformedDatasetPar2.select("prediction", "label").rdd.map {
val ovaResultsPar1 = transformedDatasetPar1.select("prediction", "label").rdd.map {
row => (row.getDouble(0), row.getDouble(1))
}

Expand All @@ -127,9 +126,9 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
row => (row.getDouble(0), row.getDouble(1))
}

val metricsPar2 = new MulticlassMetrics(ovaResultsPar2)
val metricsPar1 = new MulticlassMetrics(ovaResultsPar1)
val metricsPar4 = new MulticlassMetrics(ovaResultsPar4)
assert(metricsPar2.confusionMatrix ~== metricsPar4.confusionMatrix absTol 400)
assert(metricsPar1.confusionMatrix ~== metricsPar4.confusionMatrix absTol 400)
}

test("one-vs-rest: pass label metadata correctly during train") {
Expand Down