Created September 14, 2020 08:33
Save mkocabas/6e17e91222cc1492b9886a00cba6e9f1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
####### LOSS FUNCTION #######
class MultivariateGaussianNegativeLogLikelihood(nn.Module):
    """Negative log-likelihood of ``gt`` under a diagonal Gaussian.

    The Gaussian is parameterized by a mean ``pred_mean`` and a per-dimension
    *log standard deviation* ``pred_var``. The argument name is historical.
    The value is used as ``log(sigma)``, because the residual is divided by
    ``exp(pred_var)`` before squaring.

    For a ``D``-dimensional sample the log-density is::

        log p = -0.5 * sum(((gt - mu) / sigma) ** 2)
                - sum(log sigma)
                - 0.5 * D * log(2 * pi)

    The returned loss is the batch mean of ``-log p``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred_mean, pred_var, gt):
        """Return the batch-mean diagonal-Gaussian negative log-likelihood.

        Args:
            pred_mean: predicted means. The reduction is over dim 1, so this
                is assumed to be ``(batch, D)``. TODO: confirm for any caller
                passing higher-rank tensors.
            pred_var: predicted log standard deviations, same shape as
                ``pred_mean``.
            gt: ground-truth values, same shape as ``pred_mean``.

        Returns:
            Scalar tensor holding the mean over dim 0 of the per-sample
            negative log-likelihood.
        """
        mu = pred_mean
        logsigma = pred_var
        # Per-sample dimensionality; the sums below reduce over this same dim.
        n_dims = gt.shape[1]

        # Squared Mahalanobis distance term: -0.5 * ||(gt - mu) / sigma||^2.
        mse = -0.5 * torch.sum(torch.square((gt - mu) / torch.exp(logsigma)), dim=1)
        # Normalizer: -0.5 * log det(diag(sigma^2)) == -sum(log sigma).
        # Using -sum(sigma) here would not yield a proper Gaussian density,
        # so minimizing it would not give the maximum-likelihood sigma.
        log_det = -torch.sum(logsigma, dim=1)
        # Constant term; a numpy scalar, so it formats directly.
        log2pi = -0.5 * n_dims * np.log(2 * np.pi)

        logger.debug(f'\nMSE: {mse.mean().item():.2f}'
                     f' Sigma: {log_det.mean().item():.2f}'
                     f' log2pi:{log2pi:.2f}')

        log_likelihood = mse + log_det + log2pi
        return torch.mean(-log_likelihood)
####### HMR HEAD #######
class hmr_head(nn.Module):
    """Iterative HMR-style regressor that also emits pose/shape uncertainty.

    Starting from mean SMPL parameters loaded from ``smpl_mean_params``, the
    head refines pose (24 joints x 6D rotation), shape (10 coefficients) and
    camera (3 values) over ``n_iter`` passes. Each pass uses a shared
    two-layer MLP.

    The pose and shape decoders have doubled outputs. The first half is a
    residual update to the estimate. The second half is an unconstrained
    uncertainty output, presumably consumed as a log standard deviation by
    ``MultivariateGaussianNegativeLogLikelihood``. TODO: verify against the
    training code.
    """

    _NJOINTS = 24
    _NSHAPE = 10
    _NCAM = 3

    def __init__(
        self,
        num_input_features,
        smpl_mean_params=SMPL_MEAN_PARAMS,
    ):
        """Build the regressor.

        Args:
            num_input_features: length of the pooled, flattened feature vector.
            smpl_mean_params: path to an ``np.load``-able file with ``'pose'``,
                ``'shape'`` and ``'cam'`` entries used as initial estimates.
        """
        super().__init__()

        npose = self._NJOINTS * 6  # 6D rotation representation per joint
        self.npose = npose

        self.avgpool = nn.AvgPool2d(7, stride=1)
        # Input is [features | pose | shape | cam]; 13 == 10 shape + 3 cam.
        self.fc1 = nn.Linear(num_input_features + npose + self._NSHAPE + self._NCAM, 1024)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout()
        # Double the MLP output for pose and shape: [mean update | uncertainty].
        self.decpose = nn.Linear(1024, npose * 2)
        self.decshape = nn.Linear(1024, self._NSHAPE * 2)
        self.deccam = nn.Linear(1024, self._NCAM)

        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
        init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
        # Buffers follow the module across devices and are saved in state_dict.
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

    def forward(
        self,
        features,
        init_pose=None,
        init_shape=None,
        init_cam=None,
        n_iter=3
    ):
        """Regress SMPL pose/shape/camera and their uncertainties.

        Args:
            features: backbone feature map. Presumably ``(batch, C, 7, 7)``,
                so that the 7x7 average pool yields ``num_input_features``
                values per sample. TODO: confirm.
            init_pose: optional ``(batch, 144)`` 6D pose start; defaults to the
                mean pose.
            init_shape: optional ``(batch, 10)`` shape start; defaults to the
                mean shape.
            init_cam: optional ``(batch, 3)`` camera start; defaults to the
                mean camera.
            n_iter: number of refinement passes; must be at least 1.

        Returns:
            dict with the following entries:

            - ``'pred_pose'``: ``(batch, 24, 3, 3)`` rotation matrices.
            - ``'pred_cam'``: ``(batch, 3)`` camera.
            - ``'pred_shape'``: ``(batch, 10)`` shape.
            - ``'pred_pose_6d'``: ``(batch, 144)`` pose in 6D form.
            - ``'pred_pose_6d_var'``: pose uncertainty from the final pass.
            - ``'pred_shape_var'``: shape uncertainty from the final pass.

        Raises:
            ValueError: if ``n_iter < 1``. With zero passes no uncertainty is
                ever produced.
        """
        if n_iter < 1:
            raise ValueError(f'n_iter must be >= 1, got {n_iter}')

        batch_size = features.shape[0]
        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        xf = self.avgpool(features)
        xf = xf.view(xf.size(0), -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for _ in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            # Evaluate each doubled decoder once, then split mean/uncertainty.
            pose_out = self.decpose(xc)
            shape_out = self.decshape(xc)
            pred_pose = pose_out[:, :self.npose] + pred_pose
            pred_shape = shape_out[:, :self._NSHAPE] + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        # Uncertainty is taken from the last refinement pass only.
        pred_pose_var = pose_out[:, self.npose:]
        pred_shape_var = shape_out[:, self._NSHAPE:]

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, self._NJOINTS, 3, 3)

        output = {
            'pred_pose': pred_rotmat,
            'pred_cam': pred_cam,
            'pred_shape': pred_shape,
            'pred_pose_6d': pred_pose,
            'pred_pose_6d_var': pred_pose_var,
            'pred_shape_var': pred_shape_var,
        }
        return output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.