@@ -536,60 +536,57 @@ heuristic implemented here gives up on much further optimization.
 
 
 The choice of when to stop is a
-judgement call and a few heuristics exist*** , but these tutorials will make use
+judgement call and a few heuristics exist, but these tutorials will make use
 of a strategy based on a geometrically increasing amount of patience.
 
 .. code-block:: python
 
-    # PRE-CONDITION
-    # params refers to [initialized] parameters of our model
-
     # early-stopping parameters
-    n_iter = 100                  # the maximal number of iterations of the
-                                  # entire dataset considered
-    patience = 5000               # look at this many training examples regardless
-    patience_increase = 2         # wait this much longer when a new best
-                                  # validation error is found
+    patience = 5000               # look at this many examples regardless
+    patience_increase = 2         # wait this much longer when a new best is
+                                  # found
    improvement_threshold = 0.995  # a relative improvement of this much is
                                   # considered significant
-    validation_frequency = min(2500, patience/2.)
-                                  # make this many SGD updates between validations
+    validation_frequency = min(n_train_batches, patience/2)
+                                  # go through this many
+                                  # minibatches before checking the network
+                                  # on the validation set; in this case we
+                                  # check every epoch
 
-    # initialize cross-validation variables
-    best_params = None
+    best_params = None
     best_validation_loss = float('inf')
+    test_score = 0.
+    start_time = time.clock()
 
-    for iter in xrange(n_iter * len(train_batches)):
-
-        # get epoch and minibatch index
-        epoch = iter / len(train_batches)
-        minibatch_index = iter % len(train_batches)
+    done_looping = False
+    epoch = 0
+    while (epoch < n_epochs) and (not done_looping):
+        epoch = epoch + 1
+        for minibatch_index in xrange(n_train_batches):
 
-        # get the minibatches corresponding to `iter` modulo
-        # `len(train_batches)`
-        x, y = train_batches[minibatch_index]
+            d_loss_wrt_params = ...  # compute gradient
+            params -= learning_rate * d_loss_wrt_params  # gradient descent
 
+            # iteration number
+            iter = epoch * n_train_batches + minibatch_index
+            # note that if we do `iter % validation_frequency` it will be
+            # true for iter = 0 which we do not want
+            if iter and iter % validation_frequency == 0:
 
-        d_loss_wrt_params = ...  # compute gradient
-        params -= learning_rate * d_loss_wrt_params  # gradient descent
+                this_validation_loss = ...  # compute zero-one loss on validation set
 
-        # note that if we do `iter % validation_frequency` it will be
-        # true for iter = 0 which we do not want
-        if (iter + 1) % validation_frequency == 0:
+                if this_validation_loss < best_validation_loss:
 
-            this_validation_loss = ...  # compute zero-one loss on validation set
+                    # improve patience if loss improvement is good enough
+                    if this_validation_loss < best_validation_loss * improvement_threshold:
 
-            if this_validation_loss < best_validation_loss:
-
-                # improve patience if loss improvement is good enough
-                if this_validation_loss < best_validation_loss * improvement_threshold:
-                    patience = iter * patience_increase
-
-                best_params = copy.deepcopy(params)
-                best_validation_loss = this_validation_loss
+                        patience = max(patience, iter * patience_increase)
+                    best_params = copy.deepcopy(params)
+                    best_validation_loss = this_validation_loss
 
-        if patience <= iter:
-            break
+            if patience <= iter:
+                done_looping = True
+                break
 
     # POSTCONDITION:
     # best_params refers to the best out-of-sample parameters observed during the optimization
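The loop in this patch can be exercised end-to-end. Below is a self-contained Python 3 sketch of the same geometric-patience strategy; the gradient step and ``params`` are omitted, and a hypothetical, steadily improving validation loss (quick gains, then a plateau) stands in for a real model:

```python
def early_stopping_demo(n_epochs=1000, n_train_batches=20):
    # early-stopping parameters, mirroring the patch above
    patience = 5000                # look at this many examples regardless
    patience_increase = 2          # wait this much longer when a new best is found
    improvement_threshold = 0.995  # a relative improvement of this much is significant
    validation_frequency = min(n_train_batches, patience // 2)

    best_validation_loss = float('inf')
    best_iter = 0

    done_looping = False
    epoch = 0
    while (epoch < n_epochs) and (not done_looping):
        epoch = epoch + 1
        for minibatch_index in range(n_train_batches):
            # a real model would take a gradient step on its parameters here
            iter = (epoch - 1) * n_train_batches + minibatch_index
            if (iter + 1) % validation_frequency == 0:
                # hypothetical validation loss: improves fast, then plateaus
                this_validation_loss = 1.0 / (1.0 + iter) + 0.05
                if this_validation_loss < best_validation_loss:
                    # extend patience only when the improvement is significant
                    if this_validation_loss < best_validation_loss * improvement_threshold:
                        patience = max(patience, iter * patience_increase)
                    best_validation_loss = this_validation_loss
                    best_iter = iter
            if patience <= iter:
                done_looping = True   # also stops the outer while loop
                break
    return best_iter, best_validation_loss, patience

best_iter, best_loss, final_patience = early_stopping_demo()
```

Because the improvements become insignificant long before ``iter * patience_increase`` can exceed the initial 5000, patience is never extended and training halts once ``iter`` reaches it, with the best model recorded at the last validation beforehand.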
@@ -603,7 +600,7 @@ we just go back to the beginning of the training set and repeat.
 The ``validation_frequency`` should always be smaller than the
 ``patience``. The code should check at least two times how it
 performs before running out of patience. This is the reason we used
-the formulation ``validation_frequency = min( 2500 , patience/2.)``
+the formulation ``validation_frequency = min(value, patience/2.)``
 
 .. note::
 
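The rule of thumb above can be checked with a little arithmetic; the batch counts here are hypothetical:

```python
patience = 5000

# typical case: many small minibatches, so we validate once per epoch
n_train_batches = 20
validation_frequency = min(n_train_batches, patience // 2)  # 20
checks = patience // validation_frequency                   # 250 validations

# degenerate case: one enormous epoch; the min() caps the frequency
# at patience // 2, so at least two checks happen before stopping
huge_epoch = 100000
capped_frequency = min(huge_epoch, patience // 2)           # 2500
capped_checks = patience // capped_frequency                # exactly 2
```

Without the ``min()``, a single epoch longer than the patience would exhaust it before the network was ever evaluated on the validation set.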