Last active
April 28, 2022 14:29
-
-
Save grey-area/98033b9708827f14a2a82d2022d1cbfa to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch | |
# Local copies of project-level settings, kept in this file so the model can be
# tried out and run without importing anything else from the code base.
control_signals_labels = ["rhand", "lhand", "head"]  # control-signal names; not referenced by the code in this file
residual_block_linear = 1024  # feature width used by every ResidualBlock in Pose_Generator
# nn.BatchNorm1d expects the feature dimension at index 1, i.e. (N, C) or (N, C, L).
# The sequence tensors in this model are batch-first with features last, (N, L, C),
# so this wrapper moves the feature axis into place, normalizes, and moves it back.
class BatchNorm1d(nn.Module):
    """Batch normalization over the *last* dimension of a 2D or 3D tensor.

    Accepts ``(batch, features)`` or ``(batch, seq, features)`` input. For 3D
    input, statistics are computed per feature over both the batch and sequence
    axes, exactly as ``nn.BatchNorm1d`` does for ``(batch, features, seq)``.

    Args:
        n_features: size of the last (feature) dimension.
        **bn_kwargs: forwarded to ``nn.BatchNorm1d`` (e.g. ``eps``, ``momentum``,
            ``affine``).
    """

    def __init__(self, n_features, **bn_kwargs):
        super().__init__()
        self.bn = nn.BatchNorm1d(n_features, **bn_kwargs)

    def forward(self, x):
        # For 2D input the feature dim is both dim 1 and the last dim: no shuffle needed.
        if x.dim() == 2:
            return self.bn(x)
        # (N, L, C) -> (N, C, L) -> normalize -> back to (N, L, C)
        return self.bn(x.transpose(1, 2)).transpose(1, 2)
class ResidualBlock(nn.Module):
    """Two Linear -> BatchNorm -> LeakyReLU -> Dropout stages plus an identity skip.

    The output has the same shape as the input. Input is expected as
    ``(batch, seq, n_features)``, the layout Pose_Generator feeds it; the
    BatchNorm1d wrapper normalizes over the last dimension.

    Args:
        n_features: width of the input, hidden and output features.
        dropout: dropout probability applied after each activation
            (default 0.5, the previously hard-coded value).
    """

    def __init__(self, n_features, dropout=0.5):
        super().__init__()
        self.fc_layer1 = nn.Linear(n_features, n_features)
        self.bn1 = BatchNorm1d(n_features)
        # One parameter-free in-place activation module, reused by both stages.
        self.relu = nn.LeakyReLU(inplace=True)
        self.dropout = nn.Dropout(dropout)
        self.fc_layer2 = nn.Linear(n_features, n_features)
        self.bn2 = BatchNorm1d(n_features)

    def forward(self, x):
        residual = x
        out = self.dropout(self.relu(self.bn1(self.fc_layer1(x))))
        out = self.dropout(self.relu(self.bn2(self.fc_layer2(out))))
        # Out-of-place add on purpose. In eval mode Dropout can hand back its input
        # tensor itself, so `out` is the tensor the in-place LeakyReLU saved for
        # its backward pass. `out += residual` would modify it, and backward()
        # would then raise.
        return out + residual
class Pose_Generator(nn.Module):
    """Per-frame pose generator driven by two GRU encoders.

    Control signals and angle signals each go through their own batch-first GRU.
    The per-frame GRU outputs are concatenated, projected to ``residual_features``
    dims, and passed through residual blocks. The result is then mapped to
    ``no_of_angles`` values for every frame. Nothing depends on sequence length.

    Args:
        n_control_features: features per frame of ``control_signal``.
        n_angle_signal_features: features per frame of ``angle_signal``.
        no_of_angles: outputs per frame.
        hidden_channels: GRU hidden size (default 512).
        num_layers: number of stacked GRU layers (default 1).
        dropout: dropout between stacked GRU layers (default 0.0; PyTorch only
            applies it when ``num_layers > 1``).
        bidirectional: use bidirectional GRUs (default False). A bidirectional
            encoder also reads future frames, so it is not suitable for
            frame-by-frame streaming prediction.
        residual_features: width of the residual blocks (default: the
            module-level ``residual_block_linear``).
    """

    def __init__(self, n_control_features, n_angle_signal_features, no_of_angles,
                 *, hidden_channels=512, num_layers=1, dropout=0.0,
                 bidirectional=False, residual_features=None):
        super().__init__()
        if residual_features is None:
            residual_features = residual_block_linear
        num_directions = 2 if bidirectional else 1

        # Submodules are created in the same order as before, so parameter
        # initialization under a fixed seed is unchanged.
        self.control_flow_layer = nn.GRU(n_control_features, hidden_channels, num_layers, batch_first=True, dropout=dropout, bidirectional=bidirectional)
        self.angle_signal_flow_layer = nn.GRU(n_angle_signal_features, hidden_channels, num_layers, batch_first=True, dropout=dropout, bidirectional=bidirectional)

        # Input is each frame of the two concatenated GRU outputs, not the
        # flattened sequence. Each GRU emits hidden_channels * num_directions per frame.
        self.e2d = nn.Sequential(
            nn.Linear(2 * num_directions * hidden_channels, residual_features),
            nn.LeakyReLU(inplace=True),
            ResidualBlock(residual_features),
            ResidualBlock(residual_features),
        )
        self.generator = nn.Sequential(
            ResidualBlock(residual_features),
            ResidualBlock(residual_features),
            nn.Linear(residual_features, no_of_angles),
        )

    def encoder2d(self, control_signal, angle_signal, control_hidden=None, angle_hidden=None):
        """Encode both signal streams frame by frame.

        Args:
            control_signal: ``(batch, seq, n_control_features)``.
            angle_signal: ``(batch, seq, n_angle_signal_features)``.
            control_hidden, angle_hidden: optional initial GRU hidden states of
                shape ``(num_layers * num_directions, batch, hidden_channels)``;
                ``None`` lets the GRU start from zeros.

        Returns:
            ``(enc, (control_hidden, angle_hidden))``: ``enc`` is
            ``(batch, seq, residual_features)``, and the hidden states are the
            final GRU states, to be passed back in on the next call when streaming.
        """
        cs, control_hidden = self.control_flow_layer(control_signal, control_hidden)
        av, angle_hidden = self.angle_signal_flow_layer(angle_signal, angle_hidden)
        # No flattening: e2d and generator act on every frame independently.
        out = torch.cat((cs, av), dim=-1)
        enc = self.e2d(out)
        return enc, (control_hidden, angle_hidden)

    def forward(self, control_signal, angle_signal, control_hidden=None, angle_hidden=None):
        """Predict ``no_of_angles`` values for every input frame.

        Takes the same arguments as :meth:`encoder2d`. Returns
        ``(gen, (control_hidden, angle_hidden))`` where ``gen`` is
        ``(batch, seq, no_of_angles)``.
        """
        enc, hidden_state = self.encoder2d(control_signal, angle_signal, control_hidden, angle_hidden)
        gen = self.generator(enc)
        return gen, hidden_state
if __name__ == "__main__":
    # Made-up dimensionalities for playing with the model below.
    n_control_features = 10
    n_angle_signal_features = 3
    no_of_angles = 5
    seqlen = 15
    batch_size = 7

    control_in = torch.zeros(batch_size, seqlen, n_control_features)
    angles_in = torch.zeros(batch_size, seqlen, n_angle_signal_features)
    model = Pose_Generator(n_control_features, n_angle_signal_features, no_of_angles)

    # --- Training example ---
    # There is a target for each of the `seqlen` frames, not just the last, so
    # every sequence yields `seqlen` times as much training signal.
    # During training, never pass hidden states in; ignore the ones that come out.
    target = torch.zeros(batch_size, seqlen, no_of_angles)
    out, _ = model(control_in, angles_in)
    loss = nn.functional.mse_loss(out, target)  # e.g.; use whatever loss suits the task
    loss.backward()

    # --- Prediction example ---
    # This is how the model would be driven on the Unity side. Inputs arrive one
    # frame at a time. The loop starts from a blank (zero) hidden state and feeds
    # each returned hidden state back in on the next call. Normalization can be
    # left to whoever handles the Unity side.
    model = model.eval()

    def blank_hidden(gru):
        """Return a zero hidden state for a batch of one, shaped as `gru` expects."""
        num_directions = 2 if gru.bidirectional else 1
        return torch.zeros(gru.num_layers * num_directions, 1, gru.hidden_size)

    control_hidden = blank_hidden(model.control_flow_layer)
    angle_hidden = blank_hidden(model.angle_signal_flow_layer)

    # no_grad: without it, each returned hidden state carries the autograd graph
    # of every earlier frame, so memory grows without bound in an ongoing loop.
    with torch.no_grad():
        for _ in range(10):  # This would be an ongoing loop
            control = torch.zeros(1, 1, n_control_features)
            angle = torch.zeros(1, 1, n_angle_signal_features)
            frame_out, (control_hidden, angle_hidden) = model(control, angle, control_hidden, angle_hidden)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment