def exp_lr_scheduler(optimizer, global_step, init_lr, decay_steps, decay_rate, lr_clip, staircase=True):
    """Set an exponentially decayed learning rate on every param group of *optimizer*.

    The rate is ``init_lr * decay_rate ** (global_step / decay_steps)``, floored
    at ``lr_clip``. With ``staircase=True`` the exponent uses integer division,
    so the rate drops in discrete steps every ``decay_steps`` steps. Otherwise
    it decays continuously.

    Args:
        optimizer: Object exposing a ``param_groups`` iterable of dicts (presumably
            a ``torch.optim.Optimizer``). Each group's ``'lr'`` entry is overwritten.
        global_step: Current training step.
        init_lr: Learning rate at step 0.
        decay_steps: Number of steps per decay period. Must be non-zero; zero
            raises ``ZeroDivisionError``.
        decay_rate: Multiplicative factor applied once per decay period.
        lr_clip: Lower bound for the resulting learning rate.
        staircase: If True, decay in discrete steps; if False, decay smoothly.

    Returns:
        The learning rate that was applied to the optimizer.
    """
    if staircase:
        # Integer exponent: the rate only changes at multiples of decay_steps.
        exponent = global_step // decay_steps
    else:
        exponent = global_step / decay_steps
    lr = max(init_lr * decay_rate ** exponent, lr_clip)

    # Log only at period boundaries (including step 0) to keep output quiet.
    if global_step % decay_steps == 0:
        print(f'LR is set to {lr}')

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr