Skip to content

Commit d9108c7

Browse files
committed
Merge pull request #19 from nybbles/unsupervised-stats-fixes
Fixed stats/err reporting issues when batchsize != 1
2 parents 8d7dc02 + 6c734c8 commit d9108c7

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

3_unsupervised/3_train.lua

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ for t = 1,params.maxiter,params.batchsize do
6363
-- progress
6464
--
6565
iter = iter+1
66-
xlua.progress(iter, params.statinterval)
66+
xlua.progress(iter*params.batchsize, params.statinterval)
6767

6868
--------------------------------------------------------------------
6969
-- create mini-batch
@@ -114,7 +114,7 @@ for t = 1,params.maxiter,params.batchsize do
114114
learningRates = etas,
115115
momentum = params.momentum}
116116
_,fs = optim.sgd(feval, x, sgdconf)
117-
err = err + fs[1]
117+
err = err + fs[1]*params.batchsize -- so that err is indep of batch size
118118

119119
-- normalize
120120
if params.model:find('psd') then
@@ -124,7 +124,7 @@ for t = 1,params.maxiter,params.batchsize do
124124
--------------------------------------------------------------------
125125
-- compute statistics / report error
126126
--
127-
if math.fmod(t , params.statinterval) == 0 then
127+
if iter*params.batchsize >= params.statinterval then
128128

129129
-- report
130130
print('==> iteration = ' .. t .. ', average loss = ' .. err/params.statinterval)

0 commit comments

Comments
 (0)