Created
May 20, 2020 12:50
-
-
Save pureexe/b9f001d0888794129512dc9720cb95aa to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import division | |
| from __future__ import print_function | |
| from utils.utils import * | |
| from utils.mpi_utils import outputMPI, OrbiterDataset, evaluation, render_video | |
| from skimage import io, transform | |
| from utils.mlp import * | |
| import argparse | |
| from torch.utils.data import Dataset, DataLoader | |
| from torch.utils.data.dataloader import default_collate | |
| from torchvision.utils import save_image, make_grid | |
| from torch.utils.tensorboard import SummaryWriter | |
| from matplotlib.pyplot import imread, imsave, imshow | |
| from scipy.ndimage import gaussian_filter | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import torch as pt | |
| # import torch_sampling | |
| import numpy as np | |
| import os | |
| import sys | |
| import time | |
| import cv2 | |
| from itertools import repeat | |
| from functools import reduce | |
| from random import shuffle | |
| from torch.utils.data import SubsetRandomSampler | |
| import socket | |
| import pandas as pd | |
| from utils.sfm_utils import SfMData | |
| import json | |
| import shutil | |
| # import kornia | |
# Example run (@v3):
#   python train_random_ray.py -model_dir=ablation/hidden/128 -dataset=nerf_llff_data/fern -train_ratio=0.875 -invz -ray=12000 -offset=50 -random_split -hidden 128
parser = argparse.ArgumentParser()

# Marker used in the table below for boolean switches (action='store_true').
_SWITCH = object()

# One row per command-line option, in registration order:
#   (flag, value type, default)  or  (flag, _SWITCH, None)
# Defaults are passed to argparse untouched, so an int default on a float
# option (e.g. -sigmoid_offset) stays an int when the flag is not given.
_OPTION_SPECS = [
    ('-layers', int, 12),
    ('-sublayers', int, 6),
    ('-epochs', int, -1),
    ('-steps', int, 10),
    ('-ray', int, 10000),
    ('-ray_render', int, 100000),  # rays per chunk when predicting MPI a and c
    ('-tb_saveimage', int, 10),
    ('-tb_savempi', int, 10),
    ('-sigmoid_offset', float, 5),
    # ('-gradientclip', float, 0.01) is intentionally disabled
    ('-hidden', int, 256),
    ('-mlp', int, 4),
    ('-pos_level', int, 16),
    ('-depth_level', int, 8),
    ('-mlp_downscale', float, 4),
    ('-lrelu_slope', float, 0.01),
    ('-latent', int, 0),
    ('-concat_layer', int, 4),
    ('-neighbor', int, 0),
    ('-decay_step', int, 10000),
    ('-decay_rate', float, 0.5),
    ('-offset', int, 250),
    ('-l1', float, 20000),
    ('-gradloss', float, 5000),
    ('-depsilon', float, 0.015),
    ('-train_ratio', float, 0.875),
    ('-tvc', float, 0.0025),
    ('-partialtvc', _SWITCH, None),
    ('-dmin', float, -1),
    ('-dmax', float, -1),
    ('-scale', float, -1),
    ('-lr', float, 0.01),
    ('-nerf_width', int, 1008),
    ('-deepview_width', int, 1260),
    ('-w800', _SWITCH, None),
    ('-model_dir', str, "e1"),
    ('-restart', _SWITCH, None),
    ('-clean', _SWITCH, None),
    ('-invz', _SWITCH, None),
    # adaptive sampling
    ('-adaptive', _SWITCH, None),
    ('-adaptive_scale', float, 1),
    ('-adaptive_bias', float, 1),
    ('-constrgblr', _SWITCH, None),
    ('-stochastic', _SWITCH, None),
    ('-stochastic_upsampling', int, 1),
    ('-noeval', _SWITCH, None),
    ('-predict', _SWITCH, None),
    ('-nomask', _SWITCH, None),
    ('-random_split', _SWITCH, None),
    ('-pretrained', str, ""),
    ('-dataset', str, "scene_063"),
    ('-ref_img', str, ""),  # e.g. cam_06/image_009
    ('-img_wildcard', str, ""),
    ('-submlp', int, 0),
    ('-submlp_id', int, 0),
    ('-submlp_overlap', float, 0.5),
]

for _flag, _kind, _default in _OPTION_SPECS:
    if _kind is _SWITCH:
        parser.add_argument(_flag, action='store_true')
    else:
        parser.add_argument(_flag, type=_kind, default=_default)

# Parsed once at import time; the rest of the module reads this global.
args = parser.parse_args()
def computeHomography(sfm, feature, d):
    """Return the 3x3 homography induced by the plane at depth ``d``.

    The result is ``ref_k @ (R^T + R^T t n R^T / (-d - n R^T t)) @ K^-1`` where
    ``R``/``t`` is the pose of the view in ``feature`` expressed relative to
    the reference camera, ``n = [0, 0, 1]``, ``K`` is the view's intrinsic
    matrix and ``ref_k`` the reference camera's. Applied to a homogeneous
    pixel of the view it therefore yields a homogeneous pixel of the
    reference camera.

    Args:
        sfm: object exposing ``ref_rT``, ``ref_t`` and the dict ``ref_cam``
            (keys ``fx``, ``fy``, ``px``, ``py``).
        feature: mapping with batched entries ``r``, ``t``, ``fx``, ``fy``,
            ``px``, ``py``; only batch element 0 is used.
        d: depth of the plane (presumably along the reference camera's z
            axis -- confirm against SfMData).

    Returns:
        A 3x3 tensor.
    """
    rotation = feature['r'][0]
    translation = feature['t'][0]

    # Pose of this view relative to the reference camera.
    rel_rot = pt.matmul(rotation, sfm.ref_rT)
    rel_trans = pt.matmul(pt.matmul(-rotation, sfm.ref_rT), sfm.ref_t) + translation

    normal = pt.tensor([[0.0, 0.0, 1.0]])
    rot_t = rel_rot.t()
    outer = pt.matmul(pt.matmul(pt.matmul(rot_t, rel_trans), normal), rot_t)
    plane_shift = pt.matmul(pt.matmul(normal, rot_t), rel_trans)[0]

    # Inverse intrinsics of the view described by `feature`.
    view_k_inv = pt.tensor(
        [[feature['fx'][0], 0, feature['px'][0]],
         [0, feature['fy'][0], feature['py'][0]],
         [0, 0, 1]], dtype=pt.float).inverse()

    ref = sfm.ref_cam
    ref_k = pt.tensor(
        [[ref['fx'], 0, ref['px']],
         [0, ref['fy'], ref['py']],
         [0, 0, 1]])

    plane_term = rot_t + outer / (-d - plane_shift)
    return pt.matmul(pt.matmul(ref_k, plane_term), view_k_inv)
def computeHomoWarp(sfm,
                    input_shape, input_offset,
                    output_shape, selection,
                    feature, planes, inv=False):
    """Warp selected output pixels onto every depth plane of the MPI.

    Args:
        sfm: structure-from-motion data forwarded to ``computeHomography``.
        input_shape: ``(height, width)`` of the MPI texture being sampled.
        input_offset: value added to the warped pixel coordinates before
            normalisation (the MPI is padded by this many pixels).
        output_shape: ``(height, width)`` of the rendered image.
        selection: 1-D integer tensor of flat (row-major) pixel indices into
            the rendered image.
        feature: camera description forwarded to ``computeHomography``.
        planes: iterable of plane depths.
        inv: when true, the inverse of each homography is used.

    Returns:
        ``(warp, mask)``: ``warp`` has shape ``(len(planes), sel, 1, 2)`` and
        holds sampling coordinates normalised to ``[-1, 1]`` (the convention
        of ``F.grid_sample`` with ``align_corners=True``); ``mask`` has shape
        ``(sel, 1)`` and is 1.0 where the pixel lands inside the texture on
        both the first and the last plane, else 0.0.
    """
    width = output_shape[1]
    # coords [sel, 2] as (x, y). The row index must come from *floor*
    # division: `/` on integer tensors is true division on current PyTorch
    # and would produce fractional rows.
    coords = pt.stack([selection % width, selection // width], -1).float()
    coords = pt.cat([coords, pt.ones_like(coords[:, :1])], -1).cuda()

    # Largest valid (x, y) texture coordinate; identical for every plane.
    extent = pt.tensor([input_shape[1] - 1, input_shape[0] - 1]).cuda()

    cxys = []
    for depth in planes:
        H = computeHomography(sfm, feature, depth)
        if inv:
            H = H.inverse()
        projected = pt.matmul(coords, H.t().cuda())
        pixels = projected[None, :, :2] / projected[:, 2:] + input_offset
        cxys.append(pixels / extent * 2 - 1)

    # n, sel, 2 -> n, sel, 1, 2
    warp = pt.cat(cxys, 0)
    warp = warp.view(warp.shape[0], warp.shape[1], 1, 2)

    # Validity is only checked on the nearest and farthest planes.
    ends = pt.cat([cxys[0], cxys[-1]], 0)
    ends = ends.view(ends.shape[0], ends.shape[1], 1, 2)
    inside = (ends >= -1) & (ends <= 1)
    mask = (inside[0, :, :, 0] & inside[0, :, :, 1] &
            inside[1, :, :, 0] & inside[1, :, :, 1]).float()
    return warp, mask
# return offset for xo sampling
def getPatch(big_shape, small_shape):
    """Pick a random grid-aligned offset of a small patch inside a big area.

    Each of the first two axes is divided into ``ceil(big / small)`` slots.
    One slot is drawn uniformly per axis and mapped linearly so that slot 0
    gives offset 0 and the last slot gives ``big - small``.

    Args:
        big_shape: ``(height, width)`` of the enclosing area.
        small_shape: ``(height, width)`` of the patch.

    Returns:
        ``[offset_y, offset_x]`` as Python ints. An axis with a single slot
        always yields 0.
    """
    axes = (0, 1)
    parts = [math.ceil(big_shape[axis] / small_shape[axis]) for axis in axes]
    # A single draw covers both axes (per-axis upper bounds).
    pick = np.random.randint(0, parts)

    offsets = []
    for axis in axes:
        if parts[axis] == 1:
            offsets.append(0)
            continue
        slack = big_shape[axis] - small_shape[axis]
        offsets.append(int(np.round(pick[axis] / (parts[axis] - 1) * slack)))
    return offsets
def getPatchWithinRefGrid(sfm,
                          patch_shape, train_shape,
                          ref_offset, ref_shape,
                          feature, planes):
    """Choose a training-patch offset that overlaps the reference grid.

    The four corners of the reference rectangle (``ref_shape`` placed at
    ``ref_offset``) are pushed through the inverse homography of the first
    and last depth plane. Their bounding box is clipped so that a patch of
    ``patch_shape`` fits inside ``train_shape``, and ``getPatch`` then picks
    a random offset inside that box.

    Args:
        sfm: structure-from-motion data forwarded to ``computeHomography``.
        patch_shape: ``(height, width)`` of the patch to place.
        train_shape: ``(height, width)`` of the training image.
        ref_offset: ``(y, x)`` offset of the reference rectangle.
        ref_shape: ``(height, width)`` of the reference rectangle.
        feature: camera description forwarded to ``computeHomography``.
        planes: sequence of plane depths; only the first and last are used.

    Returns:
        ``[offset_y, offset_x]`` of the patch in the training image.
    """
    corners = (pt.Tensor([ref_offset[1], ref_offset[0], 0]) +
               pt.Tensor([[0, 0, 1],
                          [ref_shape[1] - 1, 0, 1],
                          [0, ref_shape[0] - 1, 1],
                          [ref_shape[1] - 1, ref_shape[0] - 1, 1]]))

    minx = miny = 1e10
    maxx = maxy = -1e10
    # The original loop unpacked each scalar depth into two names, which
    # raises TypeError; iterate over the two depths directly.
    for depth in (planes[0], planes[-1]):
        H = computeHomography(sfm, feature, depth).inverse()
        projected = pt.matmul(corners, H.t())
        projected = projected[:, :2] / projected[:, 2:]
        xs = projected[:, 0]
        ys = projected[:, 1]
        # .item() hands plain Python floats to math.floor / math.ceil.
        minx = min(minx, math.floor(xs.min().item()))
        maxx = max(maxx, math.ceil(xs.max().item()))
        miny = min(miny, math.floor(ys.min().item()))
        maxy = max(maxy, math.ceil(ys.max().item()))

    # Keep the box inside the image and large enough to hold one patch.
    minx = np.clip(minx, 0, train_shape[1] - patch_shape[1])
    maxx = np.clip(maxx, patch_shape[1], train_shape[1])
    miny = np.clip(miny, 0, train_shape[0] - patch_shape[0])
    maxy = np.clip(maxy, patch_shape[0], train_shape[0])

    offset = getPatch([maxy - miny, maxx - minx], patch_shape)
    return [int(miny + offset[0]), int(minx + offset[1])]
def getPlanes(sfm):
    """Return the depths of all ``args.layers * args.sublayers`` MPI planes.

    With ``sfm.invz`` the planes are spaced uniformly in inverse depth
    between ``sfm.dmin`` and ``sfm.dmax``; otherwise uniformly in depth.
    Both variants start at ``sfm.dmin`` and end at ``sfm.dmax``.

    Returns:
        1-D numpy array of plane depths.
    """
    count = args.layers * args.sublayers
    if not sfm.invz:
        return np.linspace(sfm.dmin, sfm.dmax, count)
    # Linear in dmin/depth, running from 1 (depth = dmin) to dmin/dmax.
    ratios = np.linspace(1, sfm.dmin / sfm.dmax, count)
    return 1 / ratios * sfm.dmin
def totalVariation(images):
    """Per-image anisotropic total variation of a 4-D tensor.

    Sums the absolute differences between neighbours along dims 2 and 3 over
    all non-batch dims, divides by ``(H - 1)`` and ``(W - 1)`` and multiplies
    by the fixed constant 306081.

    Args:
        images: tensor indexed as ``[batch, channel, H, W]``.

    Returns:
        1-D tensor with one value per batch element.
    """
    rows, cols = images.shape[2], images.shape[3]
    reduce_dims = [1, 2, 3]
    along_rows = (images[:, :, 1:, :] - images[:, :, :-1, :]).abs()
    along_cols = (images[:, :, :, 1:] - images[:, :, :, :-1]).abs()
    tot_var = along_rows.sum(dim=reduce_dims) + along_cols.sum(dim=reduce_dims)
    return tot_var / (rows - 1) / (cols - 1) * 306081
def stochasticize(planes, stoc):
    """Shift every plane a fraction ``stoc`` of the way to the next plane.

    Does nothing (returns ``planes`` itself) unless ``args.stochastic`` is
    set. The interpolation is linear in depth, or in inverse depth when
    ``args.invz`` is set. The last plane moves towards a virtual plane
    obtained by linear extrapolation of the last two.

    Args:
        planes: 1-D numpy array of plane depths.
        stoc: interpolation weight; 0 keeps the planes, 1 moves each plane
            onto its successor.

    Returns:
        numpy array with the same length as ``planes``.
    """
    if not args.stochastic:
        return planes

    inverse = args.invz
    spaced = 1 / planes if inverse else planes
    # Append one extrapolated sample so the last plane has a successor.
    successor = 2 * spaced[-1] - spaced[-2]
    spaced = np.concatenate([spaced, [successor]])
    blended = spaced[1:] * stoc + spaced[:-1] * (1 - stoc)
    return 1 / blended if inverse else blended
class Network2(nn.Module):
    """Multiplane image with explicit colour layers and MLP-predicted alpha.

    Colour is stored as a learnable tensor ``mpi_c`` (one RGB image per
    layer, shared by that layer's sub-layers). Alpha is predicted by the MLP
    ``seq1`` from positionally encoded (x, y, plane) coordinates.
    """

    def __init__(self, shape, sfm):
        """Create the colour layers and the alpha MLP.

        Args:
            shape: 4-element sequence; ``shape[0]`` is the number of colour
                layers and ``shape[2]``/``shape[3]`` the MPI height/width
                (``shape[1]`` is not used).
            sfm: structure-from-motion data providing ``dmin``, ``dmax`` and
                ``invz``.

        Raises:
            NotImplementedError: if ``args.latent > 0`` (latent codes are
                unsupported).
            ValueError: if ``sfm.dmin`` or ``sfm.dmax`` is negative.
        """
        super(Network2, self).__init__()
        # Latent codes were already rejected here; fail with a real error
        # instead of print + exit(), and drop the unreachable parameter.
        if args.latent > 0:
            raise NotImplementedError("unsupported")

        mpi_c = pt.empty((shape[0], 3, shape[2], shape[3])).uniform_(-2, 2)
        self.activation = nn.LeakyReLU(args.lrelu_slope)
        self.seq1 = nn.DataParallel(MLP(args.mlp, args.hidden, args.pos_level,
                                        args.depth_level, args.latent,
                                        args.concat_layer, args.lrelu_slope))
        self.mpi_c = nn.Parameter(mpi_c)

        if sfm.dmin < 0 or sfm.dmax < 0:
            raise ValueError("invalid dmin dmax")
        self.planes = getPlanes(sfm)

        print(sfm.dmin, sfm.dmax, sfm.invz)
        print('Mpi Size: {}'.format(self.mpi_c.shape))
        print('Using sublayers: {}'.format(args.sublayers > 0))
        print('Number of positional encoding: {}'.format(args.pos_level))
        print('Layer of MLP: {}'.format(args.mlp + 2))
        print('Channel of MLP: {}'.format(args.hidden))
        print('Sigmoid OFFSET: {}'.format(args.sigmoid_offset))
        if args.invz:
            print('Using inverse depth, Min depth: {}, Max depth: {}'.format(self.planes[0], self.planes[-1]))
        else:
            print('NOT Using inverse depth, Min depth: {}, Max depth: {}'.format(self.planes[0], self.planes[-1]))
        print(self.seq1)
        print(self.planes)

    def forward(self, sfm, feature, output_shape, selection, stoc=0):
        """Render the selected pixels of one view.

        Args:
            sfm: structure-from-motion data forwarded to ``computeHomoWarp``.
            feature: camera description of the view being rendered.
            output_shape: ``(height, width)`` of the rendered image.
            selection: 1-D tensor of flat pixel indices to render.
            stoc: fractional plane shift in ``[0, 1)``; only has an effect on
                the plane depths when ``args.stochastic`` is set.

        Returns:
            ``(output, mask)``: the composited colours (summed over planes
            with ``keepdim=True``) and the validity mask produced by
            ``computeHomoWarp``.

        Side effects:
            Stores ``self.mpi_sig``, ``self.mpi_a`` and, when
            ``args.partialtvc`` is set, ``self.partialtvc``.
        """
        # mpi_c: layers, 3, h, w -> layers * sublayers, 3, h, w
        mpi_c_tiled = pt.repeat_interleave(self.mpi_c, args.sublayers, 0)
        mpi_sig = pt.sigmoid(mpi_c_tiled)
        self.mpi_sig = mpi_sig

        # warp: (n, sel, 1, 2), mask: (sel, 1)
        warp, mask = computeHomoWarp(sfm,
                                     mpi_sig.shape[-2:],
                                     args.offset,
                                     output_shape, selection,
                                     feature, stochasticize(self.planes, stoc=stoc))

        n = mpi_sig.shape[0]
        bigcoords = []
        for i in range(n):
            xy_code = pt.cat([encode(warp[i, :, :, 0], args.pos_level),
                              encode(warp[i, :, :, 1], args.pos_level)], -1)
            # Plane index (plus stochastic shift) mapped to [-1, 1].
            depth_code = encode(pt.ones_like(xy_code[:, :, 0]) * (i + stoc) / (n - 1) * 2 - 1,
                                args.depth_level).cuda()
            # sel, 1, code
            bigcoords.append(pt.cat([xy_code, depth_code], -1))
        # n, sel, 1, code
        bigcoords = pt.stack(bigcoords, 0)

        # mpi_a: layers * sublayers, 1, sel, 1 after the view below
        self.mpi_a = self.seq1(bigcoords)
        self.mpi_a = self.mpi_a.view(self.mpi_a.shape[0], 1, self.mpi_a.shape[1], self.mpi_a.shape[2])
        mpi_a_sig = pt.sigmoid(self.mpi_a - args.sigmoid_offset)

        samples = F.grid_sample(mpi_sig, warp, align_corners=True)
        if args.partialtvc:
            self.partialtvc = pt.mean(totalVariation(samples))

        # Transmittance in front of each plane: product of (1 - alpha) over
        # all earlier planes (plane 0 first).
        weight = pt.cumprod(1 - pt.cat([pt.zeros_like(mpi_a_sig[:1]), mpi_a_sig[:-1]], 0), 0)
        output = pt.sum(weight * samples * mpi_a_sig, dim=0, keepdim=True)
        return output, mask
# x: h, w
def encode(x, lev=8):
    """Sinusoidal positional encoding of a 2-D tensor.

    Each value ``v`` is expanded to ``sin(0.5 * pi * 2**i * v)`` for
    ``i = 0 .. lev-1`` followed by the matching cosines.

    Args:
        x: 2-D tensor (h, w). It is not modified.
        lev: number of frequency bands.

    Returns:
        Tensor of shape ``(h, w, 2 * lev)`` on the same device as ``x``.
    """
    # h, w, lev (repeat() copies, so the in-place scaling leaves x untouched)
    scaled = x[:, :, None].repeat([1, 1, lev])
    # Build the frequencies on x's device instead of forcing CUDA: identical
    # for CUDA inputs, and CPU inputs no longer hit a device mismatch.
    freqs = pt.tensor([0.5 * np.pi * (2 ** i) for i in range(lev)], device=x.device)
    scaled *= freqs[None, None, :]
    # h, w, 2*lev
    return pt.cat([pt.sin(scaled), pt.cos(scaled)], -1)
def generateAlpha(model, dataset, dataloader, runpath, suffix=""):
    """Evaluate the model and export the full-resolution MPI.

    Unless ``args.noeval`` is set, ``evaluation`` is run first. The alpha MLP
    is then queried for every texel of every plane (in chunks of
    ``args.ray_render`` coordinates), combined with the sigmoid of the colour
    layers and handed to ``outputMPI``.

    Args:
        model: network exposing ``mpi_c`` and ``seq1``. It is switched to
            eval mode and left that way.
        dataset: dataset exposing ``sfm``.
        dataloader: loader passed to ``evaluation``.
        runpath: root directory of all runs.
        suffix: appended to ``runpath + args.model_dir`` to form the output
            directory.
    """
    out_dir = runpath + args.model_dir + suffix
    if not args.noeval:
        evaluation(model, dataset, dataloader, args.ray, out_dir)
    pt.cuda.empty_cache()

    sh = int(model.mpi_c.shape[-2])
    sw = int(model.mpi_c.shape[-1])
    print(sh, sw)

    # Texel centres normalised to [-1, 1].
    y, x = pt.meshgrid([
        (pt.arange(0, sh, dtype=pt.float)) / (sh - 1) * 2 - 1,
        (pt.arange(0, sw, dtype=pt.float)) / (sw - 1) * 2 - 1])
    coords = pt.cat([encode(x.cuda(), args.pos_level), encode(y.cuda(), args.pos_level)], -1)

    model.eval()
    n = args.layers * args.sublayers
    upsampling = args.stochastic_upsampling
    imgs = []

    if args.latent > 0:
        warp = pt.stack([x, y], -1).unsqueeze(0).cuda()

    with pt.no_grad():
        for i in range(n):
            for j in range(upsampling):
                tj = j / upsampling
                depth_code = encode(pt.ones_like(coords[:, :, 0]) * (i + tj) / (n - 1) * 2 - 1,
                                    args.depth_level).cuda()
                coords3d = pt.cat([coords, depth_code], -1)
                if args.latent > 0:
                    lat = F.grid_sample(model.latent, warp, align_corners=True)
                    coords3d = pt.cat([coords3d, lat[0]], -1)
                coords3d_linear = coords3d.view(coords3d.shape[0] * coords3d.shape[1], 1, coords3d.shape[2])

                # Query the MLP chunk by chunk to bound memory. The original
                # chunk loop reused `i`, which overwrote the plane index for
                # every sub-sample j > 0.
                stored_ray = [model.seq1(chunk)
                              for chunk in coords3d_linear.split(args.ray_render, dim=0)]
                mpi_a = pt.cat(stored_ray, 0)
                mpi_a = mpi_a.view(1, coords3d.shape[0], coords3d.shape[1])
                imgs.append(mpi_a - args.sigmoid_offset)

        # Computed under no_grad so the tensors can be converted to numpy.
        mpi_c_tiled = pt.repeat_interleave(model.mpi_c, args.sublayers * upsampling, 0)
        mpi_c_tiled = pt.sigmoid(mpi_c_tiled)
        alpha_logits = pt.stack(imgs, 0)
        print(mpi_c_tiled.shape)
        print(alpha_logits.shape)

        detached_mpi_c = mpi_c_tiled.cpu().numpy()
        # Each plane was split into `upsampling` sub-planes; this rescales
        # alpha so that `upsampling` equal sub-planes composite to the
        # original opacity.
        detached_mpi_a = (1 - pt.pow(1 - pt.sigmoid(alpha_logits), 1 / upsampling)).cpu().numpy()

    planes = getPlanes(dataset.sfm)
    if upsampling > 1:
        sets = [stochasticize(planes, k / upsampling) for k in range(upsampling)]
        # Interleave: plane 0 of every set, then plane 1 of every set, ...
        planes = [val for tup in zip(*sets) for val in tup]

    maxcol = int(16000 / sw)
    mpi_cat = np.concatenate([detached_mpi_c, detached_mpi_a], 1)
    mpi_cat = np.transpose(mpi_cat, [0, 2, 3, 1])
    outputMPI(mpi_cat,
              dataset.sfm,
              planes,
              out_dir,
              args.layers * args.sublayers * upsampling,
              1,
              args.offset,
              args.invz,
              maxcol=maxcol)
def _load_split_indices(dpath, dataset, ty):
    """Return dataset indices for the images listed in ``<dpath>/<ty>_image.txt``.

    Each non-blank line is matched as a substring of ``img['path']``; the
    first image that also contains ``args.img_wildcard`` is taken.

    Raises:
        FileNotFoundError: if the list file does not exist.
    """
    list_path = dpath + "/{}_image.txt".format(ty)
    if not os.path.exists(list_path):
        # The original `raise('...')` raised a plain string, which is itself
        # a TypeError and hid the actual problem.
        raise FileNotFoundError('No CONFIG TRAINING FILE: ' + list_path)

    data = []
    with open(list_path, "r") as fi:
        for line in fi:
            name = line.strip()
            if not name:
                # A blank line would match every path.
                continue
            for index, img in enumerate(dataset.imgs):
                if name in img['path'] and args.img_wildcard in img['path']:
                    data.append(index)
                    break
    return data


def _make_dataloaders(dataset, dpath):
    """Build the train/validation samplers and loaders (batch size 1)."""
    if args.random_split:
        sampler_train, sampler_val = generateSubsetSamplers(len(dataset), ratio=args.train_ratio)
    else:
        if os.path.exists(os.path.join(dpath, 'poses_bounds.npy')):
            # LLFF dataset: every 8th image is held out for validation.
            indices_total = list(range(len(dataset.imgs)))
            indices_val = indices_total[::8]
            held_out = set(indices_val)
            indices_train = [idx for idx in indices_total if idx not in held_out]
        else:
            indices_train = _load_split_indices(dpath, dataset, 'train')
            indices_val = _load_split_indices(dpath, dataset, 'val')
        sampler_train = SubsetRandomSampler(indices_train)
        sampler_val = SubsetRandomSampler(indices_val)

    dataloader_train = DataLoader(dataset, batch_size=1, sampler=sampler_train)
    dataloader_val = DataLoader(dataset, batch_size=1, sampler=sampler_val)
    print('TRAINING IMAGES: {}'.format(len(dataloader_train)))
    print('VALIDATE IMAGES: {}'.format(len(dataloader_val)))
    return sampler_train, dataloader_train, dataloader_val


def _restore_checkpoint(path, model, optimizer):
    """Load model and optimizer state from ``path``; return the next epoch."""
    checkpoint = pt.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'] + 1


def _sample_pixel_triplets(output_shape, num_rays):
    """Draw flat pixel indices as (anchor, anchor + 1, anchor + width) triplets.

    The triplets are interleaved so that ``[0::3]`` are the anchors,
    ``[1::3]`` the next flat index and ``[2::3]`` the index one row below.
    """
    triplets = int(num_rays / 3)
    # NOTE(review): anchors are drawn from range((h-1)*(w-1)) but used as
    # flat indices into an h*w image, so the last rows are never anchors and
    # `anchor + 1` can wrap to the next row. Kept as in the original.
    anchors = pt.tensor(np.random.choice(
        (output_shape[0] - 1) * (output_shape[1] - 1), triplets, replace=False))
    picked = pt.zeros(3 * triplets).to(pt.long)
    picked[0::3] = anchors
    picked[1::3] = anchors + 1
    picked[2::3] = anchors + output_shape[1]
    return picked


def train():
    """Train the MPI model described by the global ``args``.

    Sets up the dataset, model, optimizer and TensorBoard writer under
    ``runs/<model_dir>``, optionally resumes from a checkpoint, runs the
    optimisation loop, and finally exports the MPI and a rendered video.
    With ``-clean`` the run directory is deleted and the process exits; with
    ``-predict`` only the MPI export is performed.
    """
    runpath = "runs/"
    run_dir = runpath + args.model_dir

    if args.restart or args.clean:
        # shutil instead of a shell-built `rm -rf`: no quoting problems.
        shutil.rmtree(run_dir, ignore_errors=True)
    if args.clean:
        sys.exit()

    dpath = getOrbiterDataset(args.dataset, '/data/orbiter/datasets/nerf_llff_data/')

    # Datasets with a ref_image.txt are scaled to the deepview width, all
    # others to the nerf width.
    if os.path.exists(dpath + "/ref_image.txt"):
        with open(dpath + "/ref_image.txt", "r") as fi:
            args.ref_img = str(fi.readline().strip())
        target_width = args.deepview_width
    else:
        target_width = args.nerf_width
    if args.scale == -1:
        # Load once at scale 1 only to read the reference image width.
        sfm = SfMData(dpath,
                      ref_img=args.ref_img,
                      dmin=args.dmin,
                      dmax=args.dmax,
                      scale=1)
        args.scale = target_width / sfm.ref_cam['width']

    dataset = OrbiterDataset(dpath, ref_img=args.ref_img, scale=args.scale,
                             dmin=args.dmin,
                             dmax=args.dmax,
                             invz=args.invz,
                             neighbor=args.neighbor,
                             img_wildcard=args.img_wildcard)

    pt.manual_seed(0)
    sampler_train, dataloader_train, dataloader_val = _make_dataloaders(dataset, dpath)

    # The MPI is padded by args.offset pixels on every side.
    model = Network2((args.layers,
                      4,
                      dataset.sfm.ref_cam['height'] + args.offset * 2,
                      dataset.sfm.ref_cam['width'] + args.offset * 2,
                      ), dataset.sfm).cuda()

    writer = SummaryWriter(run_dir)
    ckpt = run_dir + "/ckpt.pt"

    lr = args.lr
    if args.constrgblr:
        # Two groups so that the decay below (group 0 only) leaves the
        # colour-layer learning rate constant.
        optimizer = pt.optim.Adam([
            {'params': model.seq1.parameters(), 'lr': lr},
            {'params': model.mpi_c, 'lr': lr}])
    else:
        optimizer = pt.optim.Adam(model.parameters(), lr=lr)

    # Record the configuration and the exact script next to the results.
    with open(run_dir + "/args.json", 'w') as out:
        json.dump(vars(args), out, indent=2, sort_keys=True)
    shutil.copy(os.path.abspath(__file__), run_dir + "/")

    start_epoch = 0
    if args.pretrained != "":
        start_epoch = _restore_checkpoint(runpath + args.pretrained + "/ckpt.pt", model, optimizer)
        print("Loading %s model at epoch %d" % (args.pretrained, start_epoch))
    # A checkpoint in the run directory takes precedence over -pretrained.
    if os.path.exists(ckpt):
        start_epoch = _restore_checkpoint(ckpt, model, optimizer)
        print("Loading model at epoch %d" % start_epoch)

    step = start_epoch * len(sampler_train)

    if args.epochs < 0 and args.steps < 0:
        print("Need to specify epochs or steps")
        sys.exit()
    if args.epochs < 0:
        args.epochs = int(np.ceil(args.steps / len(sampler_train)))

    ts = TrainingStatus(num_steps=args.epochs * len(sampler_train))

    if args.predict:
        generateAlpha(model, dataset, dataloader_val, runpath)
        sys.exit()

    for epoch in range(start_epoch, args.epochs):
        # Plain floats: accumulating the loss tensors themselves would keep
        # autograd history alive for the whole epoch.
        epoch_loss_total = 0.0
        epoch_mse = 0.0
        model.train()

        for feature in dataloader_train:
            ts.tic()
            optimizer.zero_grad()

            output_shape = feature['image'].shape[-2:]
            # Adaptive (edge-weighted) sampling is disabled: the original
            # branch was guarded by `if False` and needed the commented-out
            # kornia / torch_sampling imports.
            real_select = _sample_pixel_triplets(output_shape, args.ray)

            gt = feature['image']
            gt = gt.view(gt.shape[0], gt.shape[1], gt.shape[2] * gt.shape[3])
            gt = gt[:, :, real_select, None].cuda()

            output, mask = model(dataset.sfm, feature, output_shape, real_select,
                                 stoc=np.random.uniform() if args.stochastic else 0)

            mse = pt.mean((mask * (output - gt)) ** 2)

            if args.nomask:
                loss_recon = args.l1 * pt.mean(pt.abs(output - gt))
            else:
                loss_recon = args.l1 * pt.mean(pt.abs(mask * (output - gt)))

            if args.partialtvc:
                tvc = args.tvc * model.partialtvc
            else:
                tvc = args.tvc * pt.mean(totalVariation(pt.sigmoid(model.mpi_c[:, :3])))

            loss_total = loss_recon + tvc

            if args.gradloss > 0:
                # Finite differences between each anchor and its two
                # neighbours, only penalised where the ground truth is
                # nearly flat.
                oy = output[:, :, 0::3, :] - output[:, :, 1::3, :]
                ox = output[:, :, 0::3, :] - output[:, :, 2::3, :]
                gy = gt[:, :, 0::3, :] - gt[:, :, 1::3, :]
                gx = gt[:, :, 0::3, :] - gt[:, :, 2::3, :]
                masky = pt.abs(gy) < 0.005
                maskx = pt.abs(gx) < 0.005
                loss_total = loss_total + args.gradloss * (
                    pt.mean(maskx * pt.abs(ox - gx)) + pt.mean(masky * pt.abs(oy - gy)))

            epoch_loss_total += loss_total.item()
            epoch_mse += mse.item()

            loss_total.backward()
            optimizer.step()

            if step % args.tb_saveimage == 0:
                writer.add_image('images/2_mpic', make_grid(F.interpolate(pt.sigmoid(model.mpi_c), (int(model.mpi_c.shape[-2] * 0.3), int(model.mpi_c.shape[-1] * 0.3)), mode='area'), 4), step)

            if step % args.tb_savempi == 0 and step > 0:
                generateAlpha(model, dataset, dataloader_val, runpath, "/%06d" % step)
                pt.cuda.empty_cache()

            step += 1
            if step % args.decay_step == 0:
                lr *= args.decay_rate
                optimizer.param_groups[0]['lr'] = lr

            print(ts.toc(step, loss_total.item()))
            ts.tic()

        epoch_loss_total /= len(sampler_train)
        epoch_mse /= len(sampler_train)
        writer.add_scalar('loss/total', epoch_loss_total, step)
        writer.add_scalar('loss/mse', epoch_mse, step)

        if (epoch + 1) % 20 == 0 or epoch == args.epochs - 1:
            print("checkpointing model...")
            pt.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, ckpt)

    print('Finished Training')
    generateAlpha(model, dataset, dataloader_val, runpath)
    render_video(model, dataset, args.ray, os.path.join(runpath, 'video_output', args.model_dir))
def main():
    """Script entry point: run the training pipeline."""
    train()
if __name__ == "__main__":
    # Install the project's exception hook, configured with the directory
    # that contains this script, before starting.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    sys.excepthook = colored_hook(script_dir)
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment