Last active
May 20, 2021 10:41
-
-
Save magesh-technovator/32e894463aa0744aedebfbfa1c29ba69 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
import copy | |
import time | |
from tqdm import tqdm | |
import torch | |
import numpy as np | |
import os | |
def train_model(model, criterion, dataloaders, optimizer, metrics, bpath, num_epochs=3):
    """Train `model` and keep the weights from the epoch with the lowest Test loss.

    Parameters
    ----------
    model : torch.nn.Module
        Forward pass must return a dict containing key ``'out'`` (as the
        torchvision segmentation models do).
    criterion : callable
        Loss applied as ``criterion(outputs['out'], masks)``.
    dataloaders : dict
        Keys ``'Train'`` and ``'Test'``; each yields batches that are dicts
        with ``'image'`` and ``'mask'`` tensors.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model.parameters()``.
    metrics : dict
        Maps metric name -> ``callable(y_true, y_pred)``, evaluated per batch.
        The name ``'f1_score'`` is special-cased with a 0.1 threshold.
    bpath : str
        Directory where ``log.csv`` is (re)created and appended per epoch.
    num_epochs : int
        Number of epochs to run.

    Returns
    -------
    torch.nn.Module
        `model` loaded with the best (lowest Test-loss) weights.
    """
    # The progress bar is purely cosmetic — degrade gracefully when tqdm
    # is not installed instead of hard-failing.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(iterable):
            return iterable

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10
    # Use GPU if available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Initialize the log file for training and testing loss and metrics.
    fieldnames = ['epoch', 'Train_loss', 'Test_loss'] + \
        [f'Train_{m}' for m in metrics.keys()] + \
        [f'Test_{m}' for m in metrics.keys()]
    with open(os.path.join(bpath, 'log.csv'), 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        # BUGFIX: start each per-batch accumulator EMPTY. The original
        # seeded every list with 0, which skewed the epoch means toward 0.
        batchsummary = {a: [] for a in fieldnames}
        batchsummary['epoch'] = epoch
        # Each epoch has a training and a validation phase.
        for phase in ['Train', 'Test']:
            if phase == 'Train':
                model.train()   # training mode (dropout/batchnorm active)
            else:
                model.eval()    # evaluation mode
            running_losses = []
            # Iterate over data.
            for sample in tqdm(dataloaders[phase]):
                inputs = sample['image'].to(device)
                masks = sample['mask'].to(device)
                # Zero the parameter gradients.
                optimizer.zero_grad()
                # Track gradient history only in the training phase.
                with torch.set_grad_enabled(phase == 'Train'):
                    outputs = model(inputs)
                    loss = criterion(outputs['out'], masks)
                    y_pred = outputs['out'].data.cpu().numpy().ravel()
                    y_true = masks.data.cpu().numpy().ravel()
                    for name, metric in metrics.items():
                        if name == 'f1_score':
                            # Use a classification threshold of 0.1.
                            batchsummary[f'{phase}_{name}'].append(
                                metric(y_true > 0, y_pred > 0.1))
                        else:
                            batchsummary[f'{phase}_{name}'].append(
                                metric(y_true.astype('uint8'), y_pred))
                    # Backward + optimize only in the training phase.
                    if phase == 'Train':
                        loss.backward()
                        optimizer.step()
                running_losses.append(loss.item())
            # BUGFIX: average the loss over all batches. The original logged
            # only the LAST batch's loss and used it for model selection.
            epoch_loss = float(np.mean(running_losses)) if running_losses else 0.0
            batchsummary[f'{phase}_loss'] = epoch_loss
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            # Deep-copy the weights whenever the Test loss improves.
            # BUGFIX: compare/store a Python float, not a live GPU tensor.
            if phase == 'Test' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
        # Collapse the per-batch metric lists into epoch means.
        for field in fieldnames[3:]:
            batchsummary[field] = np.mean(batchsummary[field]) if batchsummary[field] else 0.0
        print(batchsummary)
        with open(os.path.join(bpath, 'log.csv'), 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow(batchsummary)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Lowest Loss: {:4f}'.format(best_loss))
    # Load the best model weights before returning.
    model.load_state_dict(best_model_wts)
    return model
# --- Training configuration -------------------------------------------------
# NOTE(review): `model`, `dataloaders`, `f1_score` and `roc_auc_score` are not
# defined in this snippet; they presumably come from earlier code / sklearn —
# confirm before running this standalone.
epochs = 25
bpath = "../exp/"
# Regression-style loss over the model's 'out' map.
criterion = torch.nn.MSELoss(reduction='mean')
# Adam with a low learning rate, suitable for fine-tuning.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Evaluation metrics computed per batch inside train_model.
metrics = {'f1_score': f1_score, 'auroc': roc_auc_score}

trained_model = train_model(model=model,
                            criterion=criterion,
                            dataloaders=dataloaders,
                            optimizer=optimizer,
                            metrics=metrics,
                            bpath=bpath,
                            num_epochs=epochs)

# Persist the full (pickled) model object — not just the state_dict.
weights_path = os.path.join(bpath, 'weights.pt')
torch.save(trained_model, weights_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment