Skip to content
Prev Previous commit
Next Next commit
Add wrapper function: predict & fitted
  • Loading branch information
yanboliang committed Jan 11, 2016
commit f1485ca620395810ac25300c4dd0b17ecac56031
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ private[ml] abstract class Family(val link: Link) extends Serializable {

/** The working dependent variable. */
def z(y: Double, mu: Double, eta: Double): Double

/** Linear predictors based on given mu. */
def predict(mu: Double): Double = this.link.link(mu)

/** Fitted values based on linear predictors eta. */
def fitted(eta: Double): Double = this.link.unlink(eta)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ private[ml] class IterativelyReweightedLeastSquares(
while (iter < maxIter && !converged) {

zw = y.zip(mu).map { case (y, mu) =>
val eta = family.link.link(mu)
val eta = family.predict(mu)
val z = family.z(y, mu, eta)
val w = family.weights(mu)
(z, w)
Expand All @@ -80,7 +80,7 @@ private[ml] class IterativelyReweightedLeastSquares(
eta = newInstances.map { instance =>
dot(instance.features, model.coefficients) + model.intercept
}
mu = eta.map { mu => family.link.unlink(mu) }
mu = eta.map { eta => family.fitted(eta) }

oldDev = dev
dev = family.deviance(y, mu)
Expand Down