Skip to content
Prev Previous commit
Next Next commit
still workin
  • Loading branch information
jkbradley committed Apr 29, 2015
commit a910ac7fb6966bf09be2dc71e2a9eb45820d4f58
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.Params;
import org.apache.spark.ml.param.Params$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
Expand Down Expand Up @@ -102,7 +103,7 @@ public static void main(String[] args) throws Exception {
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
*/
class MyJavaLogisticRegression
extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> implements Params {

/**
* Param for max number of iterations
Expand Down Expand Up @@ -146,7 +147,7 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap)
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
*/
class MyJavaLogisticRegressionModel
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> implements Params {

private MyJavaLogisticRegression parent_;
public MyJavaLogisticRegression parent() { return parent_; }
Expand Down
91 changes: 47 additions & 44 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ package org.apache.spark.ml.param
import java.lang.reflect.Modifier
import java.util.NoSuchElementException

import scala.collection.mutable
import scala.annotation.varargs
import scala.collection.mutable

import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.util.Identifiable
Expand Down Expand Up @@ -90,45 +90,50 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal
}
}

/** Factory methods for common validation functions for [[Param.isValid]] */
/**
* Factory methods for common validation functions for [[Param.isValid]].
* The numerical methods only support Int, Long, Float, and Double.
*/
object ParamValidate {

/**
 * Default validation: the returned function accepts every value,
 * i.e. it always returns true regardless of its argument.
 */
def default[T]: T => Boolean = { (ignored: T) =>
  true
}

/** Negate the given check */
def not[T](isValid: T => Boolean): T => Boolean = { (value: T) =>
!isValid(value)
/**
 * Private method for checking numerical types and converting to Double.
 * This is mainly for the sake of compilation; type checks are really handled
 * by [[Params]] setters and the [[ParamPair]] constructor.
 *
 * @param value value to convert; must be an Int, Long, Float, or Double
 * @return the value converted to Double
 * @throws IllegalArgumentException if the value is null or of any other type
 */
private def getDouble[T](value: T): Double = value match {
  case x: Int => x.toDouble
  case x: Long => x.toDouble
  case x: Float => x.toDouble
  case x: Double => x
  case other =>
    // The type should be checked before this is ever called.
    // Describe null explicitly: calling getClass on a null reference would raise a
    // NullPointerException and hide the intended IllegalArgumentException.
    val typeDescription = Option(other).map(_.getClass.toString).getOrElse("null")
    throw new IllegalArgumentException("Numerical Param validation failed because" +
      s" of unexpected input type: $typeDescription")
}

/** Combine two checks */
def and[T](isValid1: T => Boolean, isValid2: T => Boolean): T => Boolean = { (value: T) =>
isValid1(value) && isValid2(value)
/**
 * Builds a validation function that accepts a value only when it is strictly
 * greater than `lowerBound`. The value is converted with `getDouble` before the
 * comparison, so only Int, Long, Float, and Double inputs are supported.
 */
def gt[T](lowerBound: Double): T => Boolean =
  (candidate: T) => lowerBound < getDouble(candidate)

/** Check for value > lowerBound. Use [[not()]] for <= check. */
def gt(lowerBound: Int): Int => Boolean = { (value: Int) => value > lowerBound }

/** Check for value >= lowerBound. Use [[not()]] for < check. */
def gtEq(lowerBound: Int): Int => Boolean = { (value: Int) => value >= lowerBound }

/** Check for value > lowerBound. Use [[not()]] for <= check. */
def gt(lowerBound: Long): Long => Boolean = { (value: Long) => value > lowerBound }

/** Check for value >= lowerBound. Use [[not()]] for < check. */
def gtEq(lowerBound: Long): Long => Boolean = { (value: Long) => value >= lowerBound }

/** Check for value > lowerBound. Use [[not()]] for <= check. */
def gt(lowerBound: Float): Float => Boolean = { (value: Float) => value > lowerBound }

/** Check for value >= lowerBound. Use [[not()]] for < check. */
def gtEq(lowerBound: Float): Float => Boolean = { (value: Float) => value >= lowerBound }
/**
 * Builds a validation function that accepts a value only when it is greater than
 * or equal to `lowerBound`. The value is converted with `getDouble` before the
 * comparison, so only Int, Long, Float, and Double inputs are supported.
 */
def gtEq[T](lowerBound: Double): T => Boolean =
  (candidate: T) => lowerBound <= getDouble(candidate)

/** Check for value > lowerBound. Use [[not()]] for <= check. */
def gt(lowerBound: Double): Double => Boolean = { (value: Double) => value > lowerBound }
/**
 * Builds a validation function that accepts a value only when it is strictly
 * less than `upperBound`. The value is converted with `getDouble` before the
 * comparison, so only Int, Long, Float, and Double inputs are supported.
 */
def lt[T](upperBound: Double): T => Boolean =
  (candidate: T) => upperBound > getDouble(candidate)

/** Check for value >= lowerBound. Use [[not()]] for < check. */
def gtEq(lowerBound: Double): Double => Boolean = { (value: Double) => value >= lowerBound }
/**
 * Builds a validation function that accepts a value only when it is less than
 * or equal to `upperBound`. The value is converted with `getDouble` before the
 * comparison, so only Int, Long, Float, and Double inputs are supported.
 */
def ltEq[T](upperBound: Double): T => Boolean =
  (candidate: T) => upperBound >= getDouble(candidate)

/**
* Check for value in range lowerBound to upperBound.
Expand All @@ -137,33 +142,31 @@ object ParamValidate {
* @param upperInclusive If true, check for value <= upperBound.
* If false, check for value < upperBound.
*/
def inRange[T <: Comparable[T]](
lowerBound: T,
upperBound: T,
def inRange[T](
lowerBound: Double,
upperBound: Double,
lowerInclusive: Boolean,
upperInclusive: Boolean): T => Boolean = { (x: T) =>
val lowerValid = if (lowerInclusive) {
x.compareTo(lowerBound) >= 0
} else {
x.compareTo(lowerBound) > 0
}
val upperValid = if (upperInclusive) {
x.compareTo(upperBound) <= 0
} else {
x.compareTo(upperBound) < 0
}
upperInclusive: Boolean): T => Boolean = { (value: T) =>
val x: Double = getDouble(value)
val lowerValid = if (lowerInclusive) x >= lowerBound else x > lowerBound
val upperValid = if (upperInclusive) x <= upperBound else x < upperBound
lowerValid && upperValid
}

/** Version of [[inRange()]] which uses inclusive bounds by default: [lowerBound, upperBound] */
def inRange[T](lowerBound: T, upperBound: T): T => Boolean = {
def inRange[T](lowerBound: Double, upperBound: Double): T => Boolean = {
inRange[T](lowerBound, upperBound, lowerInclusive = true, upperInclusive = true)
}

/**
 * Builds a validation function that accepts a value only when it equals
 * some element of the given array.
 *
 * @param allowed array of permitted values
 */
def inArray[T](allowed: Array[T]): T => Boolean = { (candidate: T) =>
  allowed.exists(_ == candidate)
}

/**
 * Builds a validation function that accepts a value only when the given
 * Java list contains it.
 *
 * @param allowed list of permitted values
 */
def inArray[T](allowed: java.util.List[T]): T => Boolean = {
  val isAllowed: T => Boolean = candidate => allowed.contains(candidate)
  isAllowed
}
}

// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,9 @@ private[shared] object SharedParamsCodeGen {

s"""
|/**
| * :: DeveloperApi ::
| * Trait for shared param $name$defaultValueDoc.
| * (private[ml]) Trait for shared param $name$defaultValueDoc.
| */
|@DeveloperApi
|trait Has$Name extends Params {
|private[ml] trait Has$Name extends Params {
|
| /**
| * Param for $doc.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
* Default: 10
* @group param
*/
val rank = new IntParam(this, "rank", "rank of the factorization",
isValid = ParamValidate.gtEq[Int](1))
val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidate.gtEq[Int](1))

/** @group getParam */
def getRank: Int = getOrDefault(rank)
Expand All @@ -68,7 +67,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
* @group param
*/
val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks",
isValid = ParamValidate.gtEq[Int](1))
ParamValidate.gtEq[Int](1))

/** @group getParam */
def getNumUserBlocks: Int = getOrDefault(numUserBlocks)
Expand All @@ -78,9 +77,8 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
* Default: 10
* @group param
*/
val numItemBlocks =
new IntParam(this, "numItemBlocks", "number of item blocks",
isValid = ParamValidate.gtEq[Int](1))
val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks",
ParamValidate.gtEq[Int](1))

/** @group getParam */
def getNumItemBlocks: Int = getOrDefault(numItemBlocks)
Expand All @@ -101,7 +99,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
* @group param
*/
val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference",
isValid = ParamValidate.gtEq[Double](0))
ParamValidate.gtEq[Double](0))

/** @group getParam */
def getAlpha: Double = getOrDefault(alpha)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,7 @@ public void tearDown() {
public void testParams() {
  // Checks the default of myIntParam, then that values passed to the chained
  // setters are returned by the corresponding getters.
  JavaTestParams testParams = new JavaTestParams();
  // JUnit's assertEquals takes (expected, actual); keeping that order makes
  // failure messages report the right values.
  Assert.assertEquals(1, testParams.getMyIntParam());
  testParams.setMyIntParam(2).setMyDoubleParam(0.4);
  Assert.assertEquals(2, testParams.getMyIntParam());
  Assert.assertEquals(0.4, testParams.getMyDoubleParam(), 0.0);
}
}
34 changes: 22 additions & 12 deletions mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package org.apache.spark.ml.param;

import org.apache.spark.ml.param.Params;
import org.apache.spark.ml.param.shared.HasMaxIter;
import java.util.List;

import com.google.common.collect.Lists;

/**
* A subclass of Params for testing.
Expand All @@ -10,26 +11,35 @@ public class JavaTestParams extends JavaParams {

public IntParam myIntParam;

public DoubleParam myDoubleParam;

/** Returns the result of {@code getOrDefault(myIntParam)} as an int. */
public int getMyIntParam() {
  Integer current = (Integer) getOrDefault(myIntParam);
  return current;
}

public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); }

/**
 * Stores {@code value} for {@code myIntParam} via {@code set}.
 *
 * @return this instance, so that setter calls can be chained
 */
public JavaTestParams setMyIntParam(int value) {
  set(myIntParam, value);
  return this;
}

public DoubleParam myDoubleParam;

public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); }

/**
 * Stores {@code value} for {@code myDoubleParam} via {@code set}.
 *
 * @return this instance, so that setter calls can be chained
 */
public JavaTestParams setMyDoubleParam(double value) {
  set(myDoubleParam, value);
  return this;
}

public Param<String> myStringParam;

/** Returns the result of {@code getOrDefault(myStringParam)} as a String. */
public String getMyStringParam() {
  String current = (String) getOrDefault(myStringParam);
  return current;
}

/**
 * Stores {@code value} for {@code myStringParam} via {@code set}.
 *
 * @return this instance, so that setter calls can be chained
 */
public JavaTestParams setMyStringParam(String value) {
  set(myStringParam, value);
  return this;
}

public JavaTestParams() {
myIntParam =
new IntParam(this, "myIntParam", "this is an int param", ParamValidate.gt(0));
myDoubleParam =
new DoubleParam(this, "myDoubleParam", "this is a double param",
ParamValidate.and(ParamValidate.gtEq(0.0), ParamValidate.gt(1.0));
setDefault(myIntParam.w(1));
myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidate.gt(0));
myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param",
ParamValidate.inRange(0.0, 1.0));
List<String> validStrings = Lists.newArrayList("a", "b");
myStringParam = new Param<String>(this, "myStringParam", "this is a string param",
ParamValidate.inArray(validStrings));
setDefault(myIntParam.w(1), myDoubleParam.w(0.5));
}
}