implement tau
glenn-jocher committed Feb 25, 2022
commit 41adeaa519ee01fe4286ed61f3ff0a0e2f4433cc
3 changes: 1 addition & 2 deletions train.py
@@ -183,7 +183,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)

# EMA
-ema = ModelEMA(model, k=opt.ema_k) if RANK in [-1, 0] else None
+ema = ModelEMA(model) if RANK in [-1, 0] else None

# Resume
start_epoch, best_fitness = 0, 0.0
@@ -481,7 +481,6 @@ def parse_opt(known=False):
parser.add_argument('--quad', action='store_true', help='quad dataloader')
parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
-parser.add_argument('--ema-k', type=float, default=0.001, help='EMA ramp constant')
parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
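For context, the removed k and the new tau parameterize the same exponential ramp: exp(-x * k) equals exp(-x / tau) when tau = 1 / k, so the old default k=0.001 corresponds to tau=1000. A quick standalone check of that equivalence (plain Python, not part of the diff):

    import math

    decay, k = 0.9999, 0.001
    tau = 1 / k  # tau = 1000 reproduces the old k=0.001 ramp exactly

    for x in (1, 100, 1000, 10000):  # x = number of EMA updates so far
        old = decay * (1 - math.exp(-x * k))    # removed parameterization
        new = decay * (1 - math.exp(-x / tau))  # new tau parameterization
        assert math.isclose(old, new)
        print(f'updates={x:>5}  effective decay={new:.6f}')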
2 changes: 1 addition & 1 deletion utils/torch_utils.py
@@ -306,7 +306,7 @@ def __init__(self, model, decay=0.9999, k=0.001, updates=0):
# if next(model.parameters()).device.type != 'cpu':
# self.ema.half() # FP16 EMA
self.updates = updates # number of EMA updates
-self.decay = lambda x: decay * (1 - math.exp(-x * k))  # decay exponential ramp (to help early epochs)
+self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
for p in self.ema.parameters():
p.requires_grad_(False)

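A minimal sketch of how the new ramp behaves, assuming the __init__ signature is updated elsewhere to accept tau (the hunk context above still shows the old k=0.001 signature, and tau=2000 below is an illustrative default, not taken from this diff). The effective decay starts near zero, so early EMA updates track the raw model weights closely, then saturates toward the decay ceiling as updates accumulate:

    import math

    def ema_decay(x, decay=0.9999, tau=2000):
        # tau is assumed here for illustration; the diff only shows the lambda change
        return decay * (1 - math.exp(-x / tau))  # ramps from ~0 up toward `decay`

    for updates in (0, 100, 500, 2000, 10000):
        print(f'{updates:>6} updates -> decay {ema_decay(updates):.4f}')
    # early on, a small decay lets the EMA follow the raw weights quickly;
    # after roughly 3 * tau updates the decay sits within ~5% of its 0.9999 ceiling

Expressed as a time constant, tau reads directly as the number of updates for the ramp to cover about 63% of its range (1 - e^-1), which is arguably easier to reason about when tuning than the old rate constant k.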