Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
addressed comments
  • Loading branch information
brkyvz committed May 6, 2015
commit c81072dabdaaf9b9ce6fb08c764f302f639c273c
16 changes: 6 additions & 10 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,19 +220,15 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
}

/** Specialized version of [[Param[Array[T]]]] for Java. */
class ArrayParam[T : ClassTag](
parent: Params,
name: String,
doc: String,
isValid: Array[T] => Boolean)
extends Param[Array[T]](parent, name, doc, isValid) {
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
extends Param[Array[String]](parent, name, doc, isValid) {

def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)

override def w(value: Array[T]): ParamPair[Array[T]] = super.w(value)
override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)

private[param] def wCast(value: Seq[T]): ParamPair[Array[T]] = w(value.toArray)
private[param] def wCast(value: Seq[String]): ParamPair[Array[String]] = w(value.toArray)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would renaming wCast -> w work? I think it should compile.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It compiles but it doesn't work... The current state is the only way I got it to work.

}

/**
Expand Down Expand Up @@ -328,8 +324,8 @@ trait Params extends Identifiable with Serializable {
*/
protected final def set[T](param: Param[T], value: T): this.type = {
shouldOwn(param)
if (param.isInstanceOf[ArrayParam[_]] && value.isInstanceOf[Seq[_]]) {
paramMap.put(param.asInstanceOf[ArrayParam[Any]].wCast(value.asInstanceOf[Seq[Any]]))
if (param.isInstanceOf[StringArrayParam] && value.isInstanceOf[Seq[_]]) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is specialized for StringArrayParam. If user adds other Param types, they cannot modify the code. We can add set (or setAll) with varargs that takes ParamPair[_]*, and let users to create ParamPair first if the value type is different from the param type. For this function, paramMap.put(param.w(value)) should be sufficient.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this going to help with the Python Api though? We don't have ParamPair in Python do we?

paramMap.put(param.asInstanceOf[StringArrayParam].wCast(value.asInstanceOf[Seq[String]]))
} else {
paramMap.put(param.w(value))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ private[shared] object SharedParamsCodeGen {
case _ if c == classOf[Float] => "FloatParam"
case _ if c == classOf[Double] => "DoubleParam"
case _ if c == classOf[Boolean] => "BooleanParam"
case _ if c.isArray => s"ArrayParam[${getTypeString(c.getComponentType)}]"
case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
case _ => s"Param[${getTypeString(c)}]"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params {
* Param for input column names.
* @group param
*/
final val inputCols: ArrayParam[String] = new ArrayParam[String](this, "inputCols", "input column names")
final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names")

/** @group getParam */
final def getInputCols: Array[String] = $(inputCols)
Expand Down