import logging
import os
from datetime import datetime

import mxnet as mx


def fit(args, network, data_loader, batch_end_callback=None):
    # kvstore
    kv = mx.kvstore.create(args.kv_store)
    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % (kv.rank)
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix
    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    if 'log_dir' in args and args.log_dir is not None:
        logging.basicConfig(level=logging.DEBUG, format=head)
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        if args.log_file is None:
            log_file = ((save_model_prefix if save_model_prefix else '')
                        + datetime.now().strftime('_%Y_%m_%d-%H_%M.log'))
            log_file = log_file.replace('/', '-')
        else:
            log_file = args.log_file
        log_file_full_name = os.path.join(args.log_dir, log_file)
        handler = logging.FileHandler(log_file_full_name, mode='w')
        formatter = logging.Formatter(head)
        handler.setFormatter(formatter)
        logging.getLogger().addHandler(handler)
        logging.info('start with arguments %s', args)
    else:
        logging.basicConfig(level=logging.DEBUG, format=head)
        logging.info('start with arguments %s', args)
    # load model; mx.model.load_checkpoint returns a
    # (symbol, arg_params, aux_params) tuple, not an object with attributes
    model_args = {}
    if args.load_epoch is not None:
        assert model_prefix is not None
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_prefix, args.load_epoch)
        model_args = {'arg_params': arg_params,
                      'aux_params': aux_params,
                      'begin_epoch': args.load_epoch}
    # save model
    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)
    # data
    (train, val) = data_loader(args, kv)
    # train
    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]
    # batches per epoch; integer division keeps the lr scheduler step an int
    epoch_size = args.num_examples // args.batch_size
    if args.kv_store == 'dist_sync':
        epoch_size //= kv.num_workers
    # note: Module.fit (unlike the old mx.model.FeedForward) takes no
    # epoch_size argument, so epoch_size is only used locally for the
    # lr scheduler step below
    # disable kvstore for single device
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None
    model = mx.mod.Module(symbol=network, context=devs)
    optim = {'learning_rate': args.lr, 'wd': 1e-4, 'momentum': 0.9}
    if 'lr_factor' in args and args.lr_factor < 1:
        optim['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
            step=max(int(epoch_size * args.lr_factor_epoch), 1),
            factor=args.lr_factor)
    if 'clip_gradient' in args and args.clip_gradient is not None:
        optim['clip_gradient'] = args.clip_gradient
    model_args['optimizer_params'] = optim
    model_args['num_epoch'] = args.num_epochs
    model_args['initializer'] = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)
    eval_metric = [args.metric]  # e.g. extend with mx.metric.create('top_k_accuracy', top_k=5)
    if batch_end_callback is not None:
        if not isinstance(batch_end_callback, list):
            batch_end_callback = [batch_end_callback]
    else:
        batch_end_callback = []
    batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))
    model.fit(
        train_data=train,
        eval_data=val,
        eval_metric=eval_metric,
        kvstore=kv,
        batch_end_callback=batch_end_callback,
        epoch_end_callback=checkpoint,
        optimizer='nag',
        **model_args)
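
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist): a minimal, hypothetical driver
# showing the `args` fields fit() reads and the (train, val) pair the
# data_loader callback must return. The tiny MLP, the random NDArrayIter data,
# and every argument value below are illustrative assumptions only.
# ---------------------------------------------------------------------------
import argparse

import numpy as np


def _example():
    # small symbolic network; label name 'softmax_label' matches NDArrayIter's default
    net = mx.sym.Variable('data')
    net = mx.sym.FullyConnected(net, num_hidden=64, name='fc1')
    net = mx.sym.Activation(net, act_type='relu')
    net = mx.sym.FullyConnected(net, num_hidden=10, name='fc2')
    net = mx.sym.SoftmaxOutput(net, name='softmax')

    def data_loader(args, kv):
        # random data stands in for a real dataset; a real loader would
        # typically shard by kv.rank / kv.num_workers
        x = np.random.uniform(size=(args.num_examples, 100)).astype('float32')
        y = np.random.randint(0, 10, size=(args.num_examples,))
        train = mx.io.NDArrayIter(x, y, batch_size=args.batch_size, shuffle=True)
        val = mx.io.NDArrayIter(x, y, batch_size=args.batch_size)
        return train, val

    # one attribute per field fit() reads; values are placeholders
    args = argparse.Namespace(
        kv_store='local', model_prefix=None, save_model_prefix=None,
        load_epoch=None, gpus=None, lr=0.1, lr_factor=0.1,
        lr_factor_epoch=10, clip_gradient=None, batch_size=128,
        num_examples=1024, num_epochs=2, log_file=None, log_dir=None,
        metric='acc')
    fit(args, net, data_loader)


if __name__ == '__main__':
    _example()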