Simplify applyGradients
juliabeliaeva committed Dec 24, 2022
commit 0191501c9dc56eacea966e1655ff5e8060d2c51f
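This commit applies the same two Kotlin simplifications to every optimizer's applyGradients: the explicitly typed ArrayList is replaced by the mutableListOf() factory function, and the index-based loop over weights is replaced by destructured withIndex() iteration. A minimal self-contained sketch of both idioms (the weights list and the strings are hypothetical stand-ins for the List<Variable<Float>> and graph operands in the real code):

fun main() {
    // Hypothetical stand-in for the optimizer's List<Variable<Float>> parameter.
    val weights = listOf("conv2d_1/kernel", "conv2d_1/bias")

    // Before: explicit type annotation plus a separate indexed lookup.
    val before: MutableList<String> = ArrayList()
    for (i in weights.indices) {
        val variable = weights[i]
        before.add("apply gradient $i to $variable")
    }

    // After: mutableListOf<String>() builds the same list without the
    // redundant type annotation, and withIndex() yields (index, element)
    // pairs that destructure directly in the loop header.
    val after = mutableListOf<String>()
    for ((i, variable) in weights.withIndex()) {
        after.add("apply gradient $i to $variable")
    }

    check(before == after) // the refactoring does not change behavior
}

The loops keep the index because the unchanged parts of each optimizer's loop body still use it, so plain for (variable in weights) iteration would not be enough.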
@@ -67,14 +67,13 @@ public class AdaDelta(
         weights: List<Variable<Float>>,
         gradients: Gradients
     ): List<Operand<Float>> {
-        val targets: MutableList<Operand<Float>> =
-            ArrayList()
+        val targets = mutableListOf<Operand<Float>>()
 
         rhoConst = tf.constant(rho, getDType())
         learningRateConst = tf.constant(learningRate, getDType())
         epsilonConstant = tf.constant(epsilon, getDType())
 
-        for (i in weights.indices) {
-            val variable = weights[i]
+        for ((i, variable) in weights.withIndex()) {
             val varName = variable.ref().op().name()
 
             val accumSlot: Variable<Float> = getSlot(varName, ACCUMULATOR)
@@ -58,14 +58,12 @@ public class AdaGrad(
         weights: List<Variable<Float>>,
         gradients: Gradients
     ): List<Operand<Float>> {
-        val targets: MutableList<Operand<Float>> =
-            ArrayList()
+        val targets = mutableListOf<Operand<Float>>()
 
         initialAccumulatorValueConstant = tf.constant(initialAccumulatorValue, getDType())
         learningRateConst = tf.constant(learningRate, getDType())
 
-        for (i in weights.indices) {
-            val variable = weights[i]
+        for ((i, variable) in weights.withIndex()) {
             val varName = variable.ref().op().name()
 
             val slot: Variable<Float> = getSlot(varName, ACCUMULATOR)
@@ -70,14 +70,13 @@ public class AdaGradDA(
         weights: List<Variable<Float>>,
         gradients: Gradients
     ): List<Operand<Float>> {
-        val targets: MutableList<Operand<Float>> =
-            ArrayList()
+        val targets = mutableListOf<Operand<Float>>()
 
         learningRateConst = tf.constant(learningRate, getDType())
         l1StrengthConst = tf.constant(l1Strength, getDType())
         l2StrengthConst = tf.constant(l2Strength, getDType())
 
-        for (i in weights.indices) {
-            val variable = weights[i]
+        for ((i, variable) in weights.withIndex()) {
             val varName = variable.ref().op().name()
 
             val gradSlot: Variable<Float> = getSlot(varName, ACCUMULATOR)
@@ -73,17 +73,14 @@ public class Adam(
         weights: List<Variable<Float>>,
         gradients: Gradients
     ): List<Operand<Float>> {
-        val targets: MutableList<Operand<Float>> =
-            ArrayList()
+        val targets = mutableListOf<Operand<Float>>()
 
         betaOneConst = tf.constant(beta1, getDType())
         betaTwoConst = tf.constant(beta2, getDType())
         learningRateConst = tf.constant(learningRate, getDType())
         epsilonConstant = tf.constant(epsilon, getDType())
 
-        for (i in weights.indices) {
-
-            val variable = weights[i]
+        for ((i, variable) in weights.withIndex()) {
             val varName = variable.ref().op().name()
 
             val firstMomentSlot: Variable<Float> = getSlot(varName, FIRST_MOMENT)
@@ -72,8 +72,7 @@ public class Adamax(
         weights: List<Variable<Float>>,
         gradients: Gradients
     ): List<Operand<Float>> {
-        val targets: MutableList<Operand<Float>> =
-            ArrayList()
+        val targets = mutableListOf<Operand<Float>>()
 
         betaOneConst = tf.constant(beta1, getDType())
         betaTwoConst = tf.constant(beta2, getDType())
@@ -82,8 +81,7 @@ public class Adamax(
 
         val scope = Scope(graph.tfGraph)
 
-        for (i in weights.indices) {
-            val variable = weights[i]
+        for ((i, variable) in weights.withIndex()) {
             val varName = variable.ref().op().name()
 
             val firstMomentSlot: Variable<Float> = getSlot(varName, FIRST_MOMENT)
@@ -82,18 +82,15 @@ public class Ftrl(
         weights: List<Variable<Float>>,
         gradients: Gradients
     ): List<Operand<Float>> {
-        val targets: MutableList<Operand<Float>> =
-            ArrayList()
+        val targets = mutableListOf<Operand<Float>>()
 
         l1RegularizationStrengthConst = tf.constant(l1RegularizationStrength, getDType())
         l2RegularizationStrengthConst = tf.constant(l2RegularizationStrength, getDType())
         learningRateConst = tf.constant(learningRate, getDType())
         l2ShrinkageRegularizationStrengthConst = tf.constant(l2ShrinkageRegularizationStrength, getDType())
         learningRatePowerConst = tf.constant(learningRatePower, getDType())
 
-        for (i in weights.indices) {
-
-            val variable = weights[i]
+        for ((i, variable) in weights.withIndex()) {
             val varName = variable.ref().op().name()
 
             val accumSlot: Variable<Float> = getSlot(varName, ACCUMULATOR)
@@ -44,15 +44,12 @@ public class Momentum(
         weights: List<Variable<Float>>,
         gradients: Gradients
     ): List<Operand<Float>> {
-        val targets: MutableList<Operand<Float>> =
-            ArrayList()
+        val targets = mutableListOf<Operand<Float>>()
 
         learningRateConst = tf.constant(learningRate)
         momentumConst = tf.constant(momentum)
 
-        for (i in weights.indices) {
-            val variable = weights[i]
-
+        for ((i, variable) in weights.withIndex()) {
             val slot = getSlot(variable.ref().op().name(), MOMENTUM)
 
             targets.add(
@@ -57,16 +57,14 @@ public class RMSProp(
         weights: List<Variable<Float>>,
         gradients: Gradients
     ): List<Operand<Float>> {
-        val targets: MutableList<Operand<Float>> =
-            ArrayList()
+        val targets = mutableListOf<Operand<Float>>()
 
         decayConst = tf.constant(decay, getDType())
         momentumConst = tf.constant(momentum, getDType())
         learningRateConst = tf.constant(learningRate, getDType())
         epsilonConstant = tf.constant(epsilon, getDType())
 
-        for (i in weights.indices) {
-            val variable = weights[i]
+        for ((i, variable) in weights.withIndex()) {
             val varName = variable.ref().op().name()
 
             val rmsSlot: Variable<Float> = getSlot(varName, RMS)
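For reference, withIndex() is standard-library sugar rather than a compiler feature: it wraps the iteration in IndexedValue<T> objects, and the (i, variable) pair destructures via IndexedValue's index/value components, so the rewrite is behavior-preserving. The new loop headers desugar to roughly:

for (indexedValue in weights.withIndex()) {
    val i = indexedValue.index        // component1()
    val variable = indexedValue.value // component2()
    // ... loop body unchanged ...
}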