Skip to content

Instantly share code, notes, and snippets.

@pureexe
Created May 20, 2020 12:50
Show Gist options
  • Select an option

  • Save pureexe/821aca3d32b571b41aefbcd03afb6dc4 to your computer and use it in GitHub Desktop.

Select an option

Save pureexe/821aca3d32b571b41aefbcd03afb6dc4 to your computer and use it in GitHub Desktop.
from __future__ import division
from __future__ import print_function
from utils.utils import *
from utils.mpi_utils import outputMPI, OrbiterDataset, evaluation, render_video
from skimage import io, transform
from utils.mlp import *
import argparse
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision.utils import save_image, make_grid
from torch.utils.tensorboard import SummaryWriter
from matplotlib.pyplot import imread, imsave, imshow
from scipy.ndimage import gaussian_filter
import torch.nn.functional as F
import torch.nn as nn
import torch as pt
# import torch_sampling
import numpy as np
import os
import sys
import time
import cv2
from itertools import repeat
from functools import reduce
from random import shuffle
from torch.utils.data import SubsetRandomSampler
import socket
import pandas as pd
from utils.sfm_utils import SfMData
import json
import shutil
# import kornia
#run @v3 python train_random_ray.py -model_dir=ablation/hidden/128 -dataset=nerf_llff_data/fern -train_ratio=0.875 -invz -ray=12000 -offset=50 -random_split -hidden 128
def _build_parser():
    """Create the command-line parser for this training script.

    Flags use a single dash (e.g. ``-layers``). Comments note how the code in
    this script consumes each flag; flags marked "not read" are not used by any
    code visible in this script.
    """
    p = argparse.ArgumentParser()
    add = p.add_argument
    # MPI geometry: the number of planes is layers * sublayers (see getPlanes).
    add('-layers', type=int, default=12)
    add('-sublayers', type=int, default=6)
    # Training length: -steps is converted to epochs when -epochs < 0.
    add('-epochs', type=int, default=-1)
    add('-steps', type=int, default=10)
    # Rays sampled per training step (also passed to evaluation/render_video).
    add('-ray', type=int, default=10000)
    add('-ray_render', type=int, default=100000) #ray that use when predict mpi a and c
    # TensorBoard image logging / MPI export frequency, in steps.
    add('-tb_saveimage', type=int, default=10)
    add('-tb_savempi', type=int, default=10)
    # Subtracted from the MLP output before the alpha sigmoid.
    add('-sigmoid_offset', type=float, default=5)
    # MLP configuration (width, depth, positional-encoding octaves, ...).
    add('-hidden', type=int, default=256)
    add('-mlp', type=int, default=4)
    add('-pos_level', type=int, default=16)
    add('-depth_level', type=int, default=8)
    add('-mlp_downscale', type=float, default=4)  # not read
    add('-lrelu_slope', type=float, default=0.01)
    add('-latent', type=int, default=0)
    add('-concat_layer', type=int, default=4)
    add('-neighbor', type=int, default=0)
    # Step-wise learning-rate decay of the first optimizer param group.
    add('-decay_step', type=int, default=10000)
    add('-decay_rate', type=float, default=0.5)
    # Padding (pixels) added on every side of the reference view for the MPI.
    add('-offset', type=int, default=250)
    # Loss weights.
    add('-l1', type=float, default=20000)
    add('-gradloss', type=float, default=5000)
    add('-depsilon', type=float, default=0.015)  # not read
    add('-train_ratio', type=float, default=0.875)
    add('-tvc', type=float, default=0.0025)
    add('-partialtvc', action='store_true')
    # Depth range and image scale (-1 = derive from the data).
    add('-dmin', type=float, default=-1)
    add('-dmax', type=float, default=-1)
    add('-scale', type=float, default=-1)
    add('-lr', type=float, default=0.01)
    add('-nerf_width', type=int, default=1008)
    add('-deepview_width', type=int, default=1260)
    add('-w800', action='store_true')  # not read
    # Run directory management.
    add('-model_dir', type=str, default="e1")
    add('-restart', action='store_true')
    add('-clean', action='store_true')
    add('-invz', action='store_true')
    # adaptive sampling (only referenced by a disabled branch)
    add('-adaptive', action='store_true')
    add('-adaptive_scale', type=float, default=1)
    add('-adaptive_bias', type=float, default=1)
    # Separate optimizer groups; only the MLP group is decayed.
    add('-constrgblr', action='store_true')
    add('-stochastic', action='store_true')
    add('-stochastic_upsampling', type=int, default=1)
    add('-noeval', action='store_true')
    add('-predict', action='store_true')
    add('-nomask', action='store_true')
    add('-random_split', action='store_true')
    add('-pretrained', type=str, default="")
    # Dataset selection.
    add('-dataset', type=str, default="scene_063")
    add('-ref_img', type=str, default="") #cam_06/image_009
    add('-img_wildcard', type=str, default="")
    add('-submlp', type=int, default=0)  # not read
    add('-submlp_id', type=int, default=0)  # not read
    add('-submlp_overlap', type=float, default=0.5)  # not read
    return p


parser = _build_parser()
args = parser.parse_args()
def computeHomography(sfm, feature, d):
    """Return the homography induced by the plane z = d between a view and the reference.

    The relative pose (rel_rot, rel_trans) is formed from the view's rotation and
    translation (feature['r'][0], feature['t'][0]) and the reference's sfm.ref_rT and
    sfm.ref_t (presumably the transposed reference rotation and its translation --
    confirm against SfMData). The plane has normal (0, 0, 1). The resulting matrix is
    K_ref @ (Ha + Hb / (-d - Hc)) @ K_view^-1, so it takes homogeneous pixel
    coordinates of the view to reference pixel coordinates.

    Args:
      sfm: scene object exposing ref_rT, ref_t and ref_cam (dict with fx, fy, px, py).
      feature: batch dict; 'r', 't', 'fx', 'fy', 'px', 'py' are read at index [0].
      d: plane depth.

    Returns:
      3x3 tensor.
    """
    rot = feature['r'][0]
    trans = feature['t'][0]
    fx, fy = feature['fx'][0], feature['fy'][0]
    px, py = feature['px'][0], feature['py'][0]

    rel_rot = rot @ sfm.ref_rT
    rel_trans = ((-rot) @ sfm.ref_rT) @ sfm.ref_t + trans

    normal = pt.tensor([[0.0, 0.0, 1.0]])
    Ha = rel_rot.t()
    # `@` is left-associative, so these match the original nested matmul order.
    Hb = Ha @ rel_trans @ normal @ Ha
    Hc = (normal @ Ha @ rel_trans)[0]

    k_inv = pt.tensor([[fx, 0, px],
                       [0, fy, py],
                       [0, 0, 1]], dtype=pt.float).inverse()
    ref = sfm.ref_cam
    ref_k = pt.tensor([[ref['fx'], 0, ref['px']],
                       [0, ref['fy'], ref['py']],
                       [0, 0, 1]])
    return ref_k @ (Ha + Hb / (-d - Hc)) @ k_inv
def computeHomoWarp(sfm,
                    input_shape, input_offset,
                    output_shape, selection,
                    feature, planes, inv=False):
    """Map selected output pixels onto every MPI plane as grid_sample coordinates.

    Args:
      sfm: scene object forwarded to computeHomography.
      input_shape: (height, width) of the MPI being sampled.
      input_offset: pixel padding added to the warped coordinates before normalising.
      output_shape: (height, width) of the rendered view; `selection` indexes it row-major.
      selection: 1-D integer tensor of flat pixel indices.
      feature: batch dict forwarded to computeHomography.
      planes: iterable of plane depths.
      inv: if True, use the inverse of each homography.

    Returns:
      warp: (num_planes, num_selected, 1, 2) coordinates normalised to [-1, 1]
            (align_corners=True convention).
      mask: (num_selected, 1) float, 1 where the pixel lands inside [-1, 1] on both
            the first and the last plane.
    """
    width = output_shape[1]
    # Floor division turns a flat index into its row. `/` is true division for
    # integer tensors in modern PyTorch and would produce fractional rows.
    coords = pt.stack([selection % width, selection // width], -1).float()
    coords = pt.cat([coords, pt.ones_like(coords[:, :1])], -1).cuda()
    # (x, y) extent of the MPI used to normalise to [-1, 1]; loop invariant.
    extent = pt.tensor([input_shape[1] - 1, input_shape[0] - 1]).cuda()
    cxys = []
    for depth in planes:
        H = computeHomography(sfm, feature, depth)
        if inv:
            H = H.inverse()
        newCoords = pt.matmul(coords, H.t().cuda())
        # Dehomogenise, shift into the padded MPI, then normalise: (1, sel, 2).
        cxys.append(
            (newCoords[None, :, :2] / newCoords[:, 2:] + input_offset) / extent * 2 - 1)
    # n, sel, 1, 2
    warp = pt.cat(cxys, 0)
    warp = warp.view(warp.shape[0], warp.shape[1], 1, 2)
    # Validity is judged on the first and last planes only.
    warp2 = pt.cat([cxys[0], cxys[-1]], 0)
    warp2 = warp2.view(warp2.shape[0], warp2.shape[1], 1, 2)
    tmp = (warp2 >= -1) & (warp2 <= 1)
    mask = (tmp[0, :, :, 0] & tmp[0, :, :, 1] & tmp[1, :, :, 0] & tmp[1, :, :, 1]).float()
    return warp, mask
# Return a random [row, col] offset for sampling a patch.
def getPatch(big_shape, small_shape):
    """Return a random [row, col] offset for a small_shape patch inside big_shape.

    Along each axis, ceil(big / small) candidate offsets are spaced evenly from 0
    to big - small, and one is drawn with np.random. An axis with a single
    candidate gets offset 0. This includes empty extents or extents smaller than
    the patch, which previously made np.random.randint fail.

    Args:
      big_shape: (height, width) of the region to sample from.
      small_shape: (height, width) of the patch.

    Returns:
      [row_offset, col_offset] as Python ints.
    """
    import math  # math is not among this file's visible imports

    parts = [max(1, math.ceil(big_shape[0] / small_shape[0])),
             max(1, math.ceil(big_shape[1] / small_shape[1]))]
    # Single draw for both axes, as before, so the RNG stream is unchanged.
    pick = np.random.randint(0, parts)
    offset = []
    for axis in range(2):
        if parts[axis] == 1:
            offset.append(0)
        else:
            span = big_shape[axis] - small_shape[axis]
            offset.append(int(np.round(pick[axis] / (parts[axis] - 1) * span)))
    return offset
def getPatchWithinRefGrid(sfm,
                          patch_shape, train_shape,
                          ref_offset, ref_shape,
                          feature, planes):
    """Pick a random training-image patch inside the reprojected reference rectangle.

    The four corners of the reference rectangle (shifted by ref_offset, given as
    (row, col)) are mapped with the inverse homography of the first and the last
    plane. Their bounding box is clipped so a full patch still fits in
    train_shape, and getPatch then chooses the offset within it.

    Args:
      sfm, feature: forwarded to computeHomography.
      patch_shape: (height, width) of the patch.
      train_shape: (height, width) of the training image.
      ref_offset: (row, col) offset of the reference rectangle.
      ref_shape: (height, width) of the reference rectangle.
      planes: sequence of plane depths; only planes[0] and planes[-1] are used
              (presumably because they bound the planes in between -- confirm).

    Returns:
      [row, col] top-left corner of the patch.
    """
    import math  # math is not among this file's visible imports

    corners = (pt.Tensor([ref_offset[1], ref_offset[0], 0]) +
               pt.Tensor([[0, 0, 1],
                          [ref_shape[1] - 1, 0, 1],
                          [0, ref_shape[0] - 1, 1],
                          [ref_shape[1] - 1, ref_shape[0] - 1, 1]]))
    minx = miny = 1e10
    maxx = maxy = -1e10
    # Each entry is a scalar depth; iterate it directly (no tuple unpacking).
    for depth in (planes[0], planes[-1]):
        H = computeHomography(sfm, feature, depth).inverse()
        projected = pt.matmul(corners, H.t())
        projected = projected[:, :2] / projected[:, 2:]
        minx = min(minx, math.floor(projected[:, 0].min().item()))
        maxx = max(maxx, math.ceil(projected[:, 0].max().item()))
        miny = min(miny, math.floor(projected[:, 1].min().item()))
        maxy = max(maxy, math.ceil(projected[:, 1].max().item()))
    # Keep room for a whole patch inside the training image.
    minx = np.clip(minx, 0, train_shape[1] - patch_shape[1])
    maxx = np.clip(maxx, patch_shape[1], train_shape[1])
    miny = np.clip(miny, 0, train_shape[0] - patch_shape[0])
    maxy = np.clip(maxy, patch_shape[0], train_shape[0])
    offset = getPatch([maxy - miny, maxx - minx], patch_shape)
    return [miny + offset[0], minx + offset[1]]
def getPlanes(sfm):
    """Return the depths of all args.layers * args.sublayers MPI planes.

    Depths run from sfm.dmin to sfm.dmax. They are linear in depth, or linear in
    inverse depth when sfm.invz is set.
    """
    count = args.layers * args.sublayers
    if not sfm.invz:
        return np.linspace(sfm.dmin, sfm.dmax, count)
    # dmin / s with s going linearly from 1 to dmin / dmax: uniform in 1 / depth.
    return 1/np.linspace(1, sfm.dmin / sfm.dmax, count) * sfm.dmin
def totalVariation(images):
    """Per-image total variation of a 4-D batch (dims 2 and 3 are the spatial axes).

    Sums absolute finite differences along dims 2 and 3 over dims 1-3, divides by
    (size2 - 1) * (size3 - 1), and multiplies by the constant 306081 (its origin
    is not visible here -- presumably a reference pixel count).

    Returns:
      1-D tensor with one value per batch element.
    """
    dims = [1, 2, 3]
    down = (images[:, :, 1:, :] - images[:, :, :-1, :]).abs().sum(dim=dims)
    across = (images[:, :, :, 1:] - images[:, :, :, :-1]).abs().sum(dim=dims)
    rows, cols = images.shape[2], images.shape[3]
    return (down + across) / (rows - 1) / (cols - 1) * 306081
def stochasticize(planes, stoc):
    """Shift every plane a fraction `stoc` of the way towards the next plane.

    No-op unless -stochastic is set. The interpolation is done in inverse depth
    when -invz is set. The last plane moves towards a point extrapolated one
    spacing beyond it.
    """
    if not args.stochastic:
        return planes
    values = 1 / planes if args.invz else planes
    beyond_last = 2 * values[-1] - values[-2]
    values = np.concatenate([values, [beyond_last]])
    shifted = values[1:] * stoc + values[:-1] * (1-stoc)
    return 1 / shifted if args.invz else shifted
class Network2(nn.Module):
    """MPI model with learned colour layers and MLP-predicted alpha.

    Colour: parameter `mpi_c` of shape (shape[0], 3, shape[2], shape[3]), initialised
    uniformly in [-2, 2]. `forward` passes it through a sigmoid and repeats it
    args.sublayers times along the plane axis.
    Alpha: produced per sample by `seq1` (an MLP wrapped in nn.DataParallel) from
    positional encodings of the warped position and of the plane index.
    Most hyper-parameters come from the module-level `args`.
    """
    def __init__(self, shape, sfm):
        # Only shape[0] (colour layers), shape[2] and shape[3] (MPI height/width) are read.
        super(Network2, self).__init__()
        mpi_c = pt.empty((shape[0], 3, shape[2], shape[3])).uniform_(-2, 2)
        # Not used by any code visible in this class.
        self.activation = nn.LeakyReLU(args.lrelu_slope)
        norm = lambda x: x # legacy reason
        if args.latent > 0:
            # Latent codes are disabled: the process exits here.
            print("unsupported")
            exit()
            # NOTE(review): the source copy lost its indentation; this line may
            # originally have sat outside the `if`, in which case an empty
            # (1, 0, H, W) 'latent' parameter exists and checkpoints contain it.
            self.latent = nn.Parameter(pt.empty((1, args.latent, shape[2], shape[3])).uniform_(-1, 1))
        self.seq1 = nn.DataParallel(MLP(args.mlp, args.hidden, args.pos_level, args.depth_level, args.latent, args.concat_layer, args.lrelu_slope))
        self.mpi_c = nn.Parameter(mpi_c)
        # dmin / dmax of -1 mean "not set"; getPlanes needs real values.
        if sfm.dmin < 0 or sfm.dmax < 0:
            raise ValueError("invalid dmin dmax")
        # numpy array of args.layers * args.sublayers depths, from dmin to dmax.
        self.planes = getPlanes(sfm)
        print(sfm.dmin, sfm.dmax, sfm.invz)
        print('Mpi Size: {}'.format(self.mpi_c.shape))
        print('Using sublayers: {}'.format(args.sublayers > 0))
        print('Number of positional encoding: {}'.format(args.pos_level))
        print('Layer of MLP: {}'.format(args.mlp + 2))
        print('Channel of MLP: {}'.format(args.hidden))
        print('Sigmoid OFFSET: {}'.format(args.sigmoid_offset))
        if args.invz:
            print('Using inverse depth, Min depth: {}, Max depth: {}'.format(self.planes[0], self.planes[-1]))
        else:
            print('NOT Using inverse depth, Min depth: {}, Max depth: {}'.format(self.planes[0], self.planes[-1]))
        print(self.seq1)
        print(self.planes)

    def forward(self, sfm, feature, output_shape, selection, stoc=0):
        """Render the selected pixels of one view by compositing all MPI planes.

        Args:
          sfm: scene object forwarded to computeHomoWarp.
          feature: one DataLoader sample, forwarded to computeHomoWarp.
          output_shape: (height, width) of the rendered view.
          selection: flat pixel indices to render.
          stoc: fraction used to jitter plane depths (stochasticize) and the depth code.

        Returns:
          output: colour composited over planes, plane 0 (the dmin plane) first;
                  presumably (1, 3, num_selected, 1) -- depends on seq1's output shape.
          mask: per-pixel validity mask from computeHomoWarp.

        Side effects: sets self.mpi_sig, self.mpi_a and, with -partialtvc, self.partialtvc.
        """
        # mpi_a: layers * sublayers, 1, h, w
        # mpi_c: layers, 3, h, w -> layers * sublayers, 3, h, w
        # self.mpi_a = self.seq1(coords)
        mpi_c_tiled = pt.repeat_interleave(self.mpi_c, args.sublayers, 0)
        mpi_sig = pt.sigmoid(mpi_c_tiled)
        self.mpi_sig = mpi_sig
        # (n, sel, 1, 2), (n, sel, 1, 1)
        warp, mask = computeHomoWarp(sfm,
                                     mpi_sig.shape[-2:],
                                     args.offset,
                                     output_shape, selection,
                                     feature, stochasticize(self.planes, stoc=stoc))
        bigcoords = []
        n = mpi_sig.shape[0]
        for i in range(n):
            # Encode the warped (x, y) of plane i, then append an encoding of the
            # plane index (plus stoc) mapped to [-1, 1].
            coords = pt.cat([encode(warp[i,:,:,0], args.pos_level), encode(warp[i,:,:,1], args.pos_level)], -1)
            coords3d = pt.cat([coords, encode(pt.ones_like(coords[:, :, 0]) * (i + stoc) / (n-1) * 2 - 1, args.depth_level).cuda()], -1)
            # sel, 1, code
            bigcoords.append(coords3d)
        # n, sel, 1, code
        bigcoords = pt.stack(bigcoords, 0)
        self.mpi_a = self.seq1(bigcoords)
        #self.mpi_a = self.seq1(bigcoords)
        self.mpi_a = self.mpi_a.view(self.mpi_a.shape[0], 1, self.mpi_a.shape[1], self.mpi_a.shape[2])
        mpi_a_sig = pt.sigmoid(self.mpi_a - args.sigmoid_offset)
        # Colour of each plane at each warped position: (n, 3, sel, 1).
        samples = F.grid_sample(mpi_sig, warp, align_corners=True)
        if args.partialtvc:
            self.partialtvc = pt.mean(totalVariation(samples))
        # Transmittance: product of (1 - alpha) over all planes before plane i.
        weight = pt.cumprod(1 - pt.cat([pt.zeros_like(mpi_a_sig[:1]), mpi_a_sig[:-1]], 0), 0)
        output = pt.sum(weight * samples * mpi_a_sig, dim=0, keepdim=True)
        return output, mask
# x: h, w
def encode(x, lev=8):
    """Sinusoidal positional encoding with `lev` frequencies.

    Every element v of `x` becomes
    [sin(v*f_0), ..., sin(v*f_{lev-1}), cos(v*f_0), ..., cos(v*f_{lev-1})]
    with f_i = 0.5 * pi * 2**i, appended as a new last dimension.

    Args:
      x: tensor of any shape (callers in this file pass 2-D tensors).
      lev: number of frequencies.

    Returns:
      Tensor of shape x.shape + (2 * lev,) on x's device.
    """
    # Frequencies live on x's device instead of a hard-coded .cuda(): CPU inputs
    # work, and inputs on a non-default GPU no longer hit a device mismatch.
    freqs = pt.tensor([0.5 * np.pi * (2 ** i) for i in range(lev)],
                      dtype=pt.float, device=x.device)
    # Broadcasting replaces the explicit repeat + in-place multiply; input is not mutated.
    scaled = x[..., None] * freqs
    return pt.cat([pt.sin(scaled), pt.cos(scaled)], -1)
def _predictAlphaChunked(model, coords3d):
    """Run model.seq1 over an (h, w, code) grid in chunks of args.ray_render rays.

    Returns the raw MLP output (before the offset/sigmoid) reshaped to (1, h, w).
    """
    height, width = coords3d.shape[0], coords3d.shape[1]
    rays = coords3d.view(height * width, 1, coords3d.shape[2])
    # pt.split yields full chunks plus a final remainder, bounding peak memory.
    chunks = [model.seq1(chunk) for chunk in pt.split(rays, args.ray_render)]
    return pt.cat(chunks, 0).view(1, height, width)


def generateAlpha(model, dataset, dataloader, runpath, suffix=""):
    """Optionally evaluate the model, then export its full MPI (colour + alpha).

    Runs `evaluation` unless -noeval is set. It then queries the alpha MLP for
    every plane -- and every stochastic sub-offset when -stochastic_upsampling > 1
    -- over the whole MPI grid, and writes the result with `outputMPI` to
    runpath + args.model_dir + suffix. The model's train/eval mode is restored
    before returning.

    Args:
      model: Network2 instance (reads mpi_c, seq1 and, with -latent, latent).
      dataset: provides dataset.sfm for plane depths and export.
      dataloader: forwarded to `evaluation`.
      runpath: output root, concatenated with args.model_dir + suffix.
      suffix: optional sub-path for the output directory.
    """
    was_training = model.training
    out_dir = runpath + args.model_dir + suffix
    if not args.noeval:
        evaluation(model, dataset, dataloader, args.ray , out_dir)
        pt.cuda.empty_cache()
    sh = int(model.mpi_c.shape[-2])
    sw = int(model.mpi_c.shape[-1])
    print(sh, sw)
    # Normalised [-1, 1] grid covering the whole (padded) MPI.
    y, x = pt.meshgrid([
        (pt.arange(0, sh, dtype=pt.float)) / (sh-1) * 2 - 1,
        (pt.arange(0, sw, dtype=pt.float)) / (sw-1) * 2 - 1])
    coords = pt.cat([encode(x.cuda(), args.pos_level), encode(y.cuda(), args.pos_level)], -1)
    model.eval()
    n = args.layers * args.sublayers
    upsample = args.stochastic_upsampling
    imgs = []
    if args.latent > 0:
        warp = pt.stack([x, y], -1).unsqueeze(0).cuda()
    with pt.no_grad():
        for layer in range(n):
            for j in range(upsample):
                tj = j / upsample
                # Same depth code as Network2.forward: (plane index + offset) mapped to [-1, 1].
                # `layer` is never reassigned inside this loop, so every sub-offset j of
                # this plane gets the right depth (the chunk loop used to clobber it).
                depth = pt.ones_like(coords[:, :, 0]) * (layer + tj) / (n-1) * 2 - 1
                coords3d = pt.cat([coords, encode(depth, args.depth_level).cuda()], -1)
                if args.latent > 0:
                    lat = F.grid_sample(model.latent, warp, align_corners=True)
                    coords3d = pt.cat([coords3d, lat[0]], -1)
                imgs.append(_predictAlphaChunked(model, coords3d) - args.sigmoid_offset)
        # Computed under no_grad so .numpy() below works on the parameter-derived tensor.
        mpi_c_tiled = pt.sigmoid(pt.repeat_interleave(model.mpi_c, args.sublayers * upsample, 0))
    alpha_logits = pt.stack(imgs, 0)
    print(mpi_c_tiled.shape)
    print(alpha_logits.shape)
    planes = getPlanes(dataset.sfm)
    if upsample > 1:
        # Interleave the jittered plane sets so depths follow the (layer, j) order of imgs.
        sets = [stochasticize(planes, k / upsample) for k in range(upsample)]
        planes = [val for tup in zip(*sets) for val in tup]
    maxcol = int(16000 / sw)
    with pt.no_grad():
        detached_mpi_c = mpi_c_tiled.cpu().numpy()
        # Split each plane's opacity a over `upsample` sub-planes a' such that
        # 1 - (1 - a') ** upsample == a.
        detached_mpi_a = (1 - pt.pow(1 - pt.sigmoid(alpha_logits), 1/upsample)).cpu().numpy()
        mpi_cat = np.concatenate([detached_mpi_c, detached_mpi_a], 1)
        mpi_cat = np.transpose(mpi_cat, [0, 2, 3, 1])
    outputMPI(mpi_cat,
              dataset.sfm,
              planes,
              out_dir,
              args.layers * args.sublayers * upsample,
              1,
              args.offset,
              args.invz,
              maxcol=maxcol)
    model.train(was_training)
def _readSplitIndices(dataset, dpath, ty):
    """Return indices into dataset.imgs for the names listed in <dpath>/<ty>_image.txt.

    Each non-blank line is matched as a substring of an image path (which must also
    contain args.img_wildcard). The first matching image wins.

    Raises:
      FileNotFoundError: if the list file does not exist.
    """
    list_path = dpath + "/{}_image.txt".format(ty)
    if not os.path.exists(list_path):
        # A real exception type; raising a plain str is itself a TypeError.
        raise FileNotFoundError('No CONFIG TRAINING FILE: ' + list_path)
    data = []
    with open(list_path, "r") as fi:
        for line in fi:
            name = line.strip()
            if not name:
                # "" is a substring of every path and would silently pick image 0.
                continue
            for count, img in enumerate(dataset.imgs):
                if name in img['path'] and args.img_wildcard in img['path']:
                    data.append(count)
                    break
    return data


def _selectTrainValIndices(dataset, dpath):
    """Return (train_indices, val_indices) for the deterministic (non-random) split."""
    if os.path.exists(os.path.join(dpath, 'poses_bounds.npy')):
        # LLFF convention: every 8th image is held out for validation.
        indices_total = list(range(len(dataset.imgs)))
        indices_val = indices_total[::8]
        held_out = set(indices_val)
        indices_train = [k for k in indices_total if k not in held_out]
        return indices_train, indices_val
    return _readSplitIndices(dataset, dpath, 'train'), _readSplitIndices(dataset, dpath, 'val')


def _sampleRayTriplets(output_shape, num_rays):
    """Sample int(num_rays / 3) pixels together with their right and lower neighbours.

    Returns a 1-D LongTensor of flat indices into an (H, W) = output_shape image,
    laid out as [p0, p0 + 1, p0 + W, p1, p1 + 1, p1 + W, ...].
    """
    height, width = int(output_shape[0]), int(output_shape[1])
    # Draw anchors from the (H-1) x (W-1) grid so p + 1 stays in the same row and
    # p + W stays inside the image; a flat draw could wrap p + 1 to the next row.
    anchors = np.random.choice((height - 1) * (width - 1), int(num_rays / 3), replace=False)
    rows, cols = np.divmod(anchors, width - 1)
    selection = pt.as_tensor(rows * width + cols, dtype=pt.long)
    return pt.stack([selection, selection + 1, selection + width], 1).view(-1)


def train():
    """Train the MPI model configured by the command-line `args`, then export results.

    1. Optionally wipe runs/<model_dir> (-restart / -clean; -clean exits afterwards).
    2. Load the dataset and build the train/val loaders.
    3. Create or resume the model and optimizer from runs/<model_dir>/ckpt.pt or -pretrained.
    4. Train with TensorBoard logging and periodic MPI exports.
    5. Export the final MPI and render a video.
    """
    runpath = "runs/"
    run_dir = runpath + args.model_dir
    if args.restart or args.clean:
        # shutil instead of shelling out to `rm -rf`: no quoting/injection issues.
        shutil.rmtree(run_dir, ignore_errors=True)
    if args.clean:
        sys.exit()

    dpath = getOrbiterDataset(args.dataset, '/data/orbiter/datasets/nerf_llff_data/')
    has_ref_file = os.path.exists(dpath + "/ref_image.txt")
    if has_ref_file:
        with open(dpath + "/ref_image.txt", "r") as fi:
            args.ref_img = str(fi.readline().strip())
    if args.scale == -1:
        # Derive the scale from the reference width: datasets with ref_image.txt
        # target -deepview_width, the others -nerf_width.
        sfm = SfMData(dpath,
                      ref_img=args.ref_img,
                      dmin=args.dmin,
                      dmax=args.dmax,
                      scale=1)
        target_width = args.deepview_width if has_ref_file else args.nerf_width
        args.scale = target_width / sfm.ref_cam['width']
    dataset = OrbiterDataset(dpath, ref_img=args.ref_img, scale=args.scale,
                             dmin=args.dmin,
                             dmax=args.dmax,
                             invz=args.invz,
                             neighbor=args.neighbor,
                             img_wildcard=args.img_wildcard)
    pt.manual_seed(0)
    if args.random_split:
        sampler_train, sampler_val = generateSubsetSamplers(len(dataset), ratio=args.train_ratio)
    else:
        indices_train, indices_val = _selectTrainValIndices(dataset, dpath)
        sampler_train = SubsetRandomSampler(indices_train)
        sampler_val = SubsetRandomSampler(indices_val)
    dataloader_train = DataLoader(dataset, batch_size=1, sampler=sampler_train)
    dataloader_val = DataLoader(dataset, batch_size=1, sampler=sampler_val)
    print('TRAINING IMAGES: {}'.format(len(dataloader_train)))
    print('VALIDATE IMAGES: {}'.format(len(dataloader_val)))

    # The MPI extends args.offset pixels beyond every side of the reference view.
    model = Network2((args.layers,
                      4,
                      dataset.sfm.ref_cam['height'] + args.offset * 2,
                      dataset.sfm.ref_cam['width'] + args.offset * 2,
                      ), dataset.sfm).cuda()
    writer = SummaryWriter(run_dir)
    ckpt = run_dir + "/ckpt.pt"
    if args.constrgblr:
        # Two groups: only group 0 (the MLP) is decayed below; mpi_c keeps args.lr.
        optimizer = pt.optim.Adam([
            {'params': model.seq1.parameters(), 'lr': args.lr},
            {'params': model.mpi_c, 'lr': args.lr}])
    else:
        optimizer = pt.optim.Adam(model.parameters(), lr=args.lr)
    with open(run_dir + "/args.json", 'w') as out:
        json.dump(vars(args), out, indent=2, sort_keys=True)
    # Keep a copy of this script next to the run for reproducibility.
    shutil.copy(os.path.abspath(__file__), run_dir + "/")

    start_epoch = 0
    if args.pretrained != "":
        checkpoint = pt.load(runpath + args.pretrained + "/ckpt.pt")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print("Loading %s model at epoch %d" % (args.pretrained, start_epoch))
    if os.path.exists(ckpt):
        checkpoint = pt.load(ckpt)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print("Loading model at epoch %d" % start_epoch)
    step = start_epoch * len(sampler_train)
    # Continue the step-decay schedule where it left off. Restarting `lr` from
    # args.lr would raise the rate again at the next decay step after a resume.
    lr = args.lr * args.decay_rate ** (step // args.decay_step)
    optimizer.param_groups[0]['lr'] = lr

    if args.epochs < 0 and args.steps < 0:
        print("Need to specify epochs or steps")
        sys.exit()
    if args.epochs < 0:
        args.epochs = int(np.ceil(args.steps / len(sampler_train)))
    ts = TrainingStatus(num_steps=args.epochs * len(sampler_train))
    if args.predict:
        generateAlpha(model, dataset, dataloader_val, runpath)
        sys.exit()

    for epoch in range(start_epoch, args.epochs):
        epoch_loss_total = 0
        epoch_mse = 0
        model.train()
        for feature in dataloader_train:
            ts.tic()
            optimizer.zero_grad()
            output_shape = feature['image'].shape[-2:]
            real_select = _sampleRayTriplets(output_shape, args.ray)
            gt = feature['image']
            gt = gt.view(gt.shape[0], gt.shape[1], gt.shape[2] * gt.shape[3])
            gt = gt[:, :, real_select, None].cuda()
            output, mask = model(dataset.sfm, feature, output_shape, real_select,
                                 stoc=np.random.uniform() if args.stochastic else 0)
            mse = pt.mean((mask * (output - gt)) ** 2)
            if args.nomask:
                loss_recon = args.l1 * pt.mean(pt.abs(output - gt))
            else:
                loss_recon = args.l1 * pt.mean(pt.abs(mask * (output - gt)))
            if args.partialtvc:
                tvc = args.tvc * model.partialtvc
            else:
                tvc = args.tvc * pt.mean(totalVariation(pt.sigmoid(model.mpi_c[:, :3])))
            loss_total = loss_recon + tvc
            if args.gradloss > 0:
                # Triplet layout from _sampleRayTriplets: [p, right neighbour, lower neighbour].
                out_dx = output[:, :, 0::3, :] - output[:, :, 1::3, :]
                out_dy = output[:, :, 0::3, :] - output[:, :, 2::3, :]
                gt_dx = gt[:, :, 0::3, :] - gt[:, :, 1::3, :]
                gt_dy = gt[:, :, 0::3, :] - gt[:, :, 2::3, :]
                # Only penalise gradient mismatch where the ground truth is nearly flat.
                mask_dx = pt.abs(gt_dx) < 0.005
                mask_dy = pt.abs(gt_dy) < 0.005
                loss_total = loss_total + args.gradloss * (
                    pt.mean(mask_dy * pt.abs(out_dy - gt_dy)) +
                    pt.mean(mask_dx * pt.abs(out_dx - gt_dx)))
            loss_total.backward()
            optimizer.step()
            # Accumulate floats; summing the tensors keeps every step's autograd graph alive.
            epoch_loss_total += loss_total.item()
            epoch_mse += mse.item()
            if step % args.tb_saveimage == 0:
                thumb = (int(model.mpi_c.shape[-2] * 0.3), int(model.mpi_c.shape[-1] * 0.3))
                writer.add_image('images/2_mpic', make_grid(F.interpolate(pt.sigmoid(model.mpi_c), thumb, mode='area'), 4), step)
            if step % args.tb_savempi == 0 and step > 0:
                generateAlpha(model, dataset, dataloader_val, runpath, "/%06d" % step)
                pt.cuda.empty_cache()
                # generateAlpha calls model.eval(); make sure training mode is back on.
                model.train()
            step += 1
            if step % args.decay_step == 0:
                lr *= args.decay_rate
                optimizer.param_groups[0]['lr'] = lr
            print(ts.toc(step, loss_total.item()))
            ts.tic()
        epoch_loss_total /= len(sampler_train)
        epoch_mse /= len(sampler_train)
        writer.add_scalar('loss/total', epoch_loss_total, step)
        writer.add_scalar('loss/mse', epoch_mse, step)
        if (epoch + 1) % 20 == 0 or epoch == args.epochs - 1:
            print("checkpointing model...")
            pt.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, ckpt)
    writer.close()
    print('Finished Training')
    generateAlpha(model, dataset, dataloader_val, runpath)
    render_video(model, dataset, args.ray, os.path.join(runpath, 'video_output', args.model_dir))
def main():
    """Script entry point: run the complete training pipeline (see `train`)."""
    train()
if __name__ == "__main__":
    # Install the project's coloured traceback hook, rooted at this script's directory.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    sys.excepthook = colored_hook(script_dir)
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment