Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@

package org.apache.spark.mllib.feature

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -77,8 +75,8 @@ class StandardScalerModel private[mllib] (

require(mean.size == variance.size)

private lazy val factor: BDV[Double] = {
val f = BDV.zeros[Double](variance.size)
private lazy val factor: Array[Double] = {
val f = Array.ofDim[Double](variance.size)
var i = 0
while (i < f.size) {
f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0
Expand All @@ -87,6 +85,11 @@ class StandardScalerModel private[mllib] (
f
}

// Since `shift` will be only used in `withMean` branch, we have it as
// `lazy val` so it will be evaluated in that branch. Note that we don't
// want to create this array multiple times in `transform` function.
private lazy val shift: Array[Double] = mean.toArray
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need lazy here, because mean.toArray is not expensive.


/**
* Applies standardization transformation on a vector.
*
Expand All @@ -97,30 +100,57 @@ class StandardScalerModel private[mllib] (
override def transform(vector: Vector): Vector = {
require(mean.size == vector.size)
if (withMean) {
vector.toBreeze match {
case dv: BDV[Double] =>
val output = vector.toBreeze.copy
var i = 0
while (i < output.length) {
output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0)
i += 1
// By default, Scala generates Java methods for member variables. So every time when
// the member variables are accessed, `invokespecial` will be called which is expensive.
// This can be avoid by having a local reference of `shift`.
val localShift = shift
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is worth to leave a comment here and explain why we need local reference.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shift is only used in this branch. Shall we just put val shift = mean.toArray here instead of having a member variable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I'll change it back to lazy since it will not be evaluated in those branches which don't use shift. I don't want to create shift array/object for each sample since shift will always be the same.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shift only holds a reference to mean.values. We don't really need to define it as a member and make it lazy. It should give the same performance if we only define it inside the if branch.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For different implementation of vector, toArray can be very expensive. For example, toArray for sparse vector requires to create a new array object and loop through all the non zero values. As a result, we can have a global lazy shift which can prevent this happens.

vector match {
case dv: DenseVector =>
val values = dv.values.clone()
val size = values.size
if (withStd) {
// Having a local reference of `factor` to avoid overhead as the comment before.
val localFactor = factor
var i = 0
while (i < size) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we move var i = 0 inside each closure? It feels safer.

values(i) = (values(i) - localShift(i)) * localFactor(i)
i += 1
}
} else {
var i = 0
while (i < size) {
values(i) -= localShift(i)
i += 1
}
}
Vectors.fromBreeze(output)
Vectors.dense(values)
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
} else if (withStd) {
vector.toBreeze match {
case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor)
case sv: BSV[Double] =>
// Having a local reference of `factor` to avoid overhead as the comment before.
val localFactor = factor
vector match {
case dv: DenseVector =>
val values = dv.values.clone()
val size = values.size
var i = 0
while(i < size) {
values(i) *= localFactor(i)
i += 1
}
Vectors.dense(values)
case sv: SparseVector =>
// For sparse vector, the `index` array inside sparse vector object will not be changed,
// so we can re-use it to save memory.
val output = new BSV[Double](sv.index, sv.data.clone(), sv.length)
val indices = sv.indices
val values = sv.values.clone()
val nnz = values.size
var i = 0
while (i < output.data.length) {
output.data(i) *= factor(output.index(i))
while (i < nnz) {
values(i) *= localFactor(indices(i))
i += 1
}
Vectors.fromBreeze(output)
Vectors.sparse(sv.size, indices, values)
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
} else {
Expand Down