import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv1dWithMask(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, bias=True,
                 w_init_gain='linear'):
        super(Conv1dWithMask, self).__init__()
        assert kernel_size > 1, "Conv1dWithMask kernel size must be greater than 1"
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, bias=bias)
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x, mask=None):
        """
        :param x: [B, H, T]
        :param mask: [B, T, T], e.g.:
            tensor([[[1., 1., 0., 0., 0., 0., 0., 0.],
                     [1., 1., 0., 0., 0., 0., 0., 0.],
                     [1., 1., 1., 1., 0., 0., 0., 0.],
                     [1., 1., 1., 1., 0., 0., 0., 0.],
                     [1., 1., 1., 1., 1., 1., 1., 1.],
                     [1., 1., 1., 1., 1., 1., 1., 1.],
                     [1., 1., 1., 1., 1., 1., 1., 1.],
                     [1., 1., 1., 1., 1., 1., 1., 1.]],
                    ...])
        :return: [B, H', T]
        """
        # Also accept (x, mask) packed into a single list argument.
        if isinstance(x, list):
            assert len(x) == 2
            x, mask = x[0], x[1]
        assert mask is not None

        x = x.permute(0, 2, 1)  # [B, H, T] -> [B, T, H]
        kernel_size = self.kernel_size  # K (shape comments assume odd K)
        B, T, H = x.shape

        # Pad the mask's last dim, then shift row t left by t positions via a
        # flatten/reshape trick, so that row t of mask_pad_shift holds
        # mask_pad[:, t, t:t+K] -- the mask over position t's receptive field.
        # (The column appended by torch.cat consists of padding zeros because
        # K > 1, which keeps the final reshaped row correct.)
        mask_pad = F.pad(mask, [kernel_size // 2, kernel_size // 2])  # [B, T, T+K-1]
        mask_pad_shift = torch.cat(
            [mask_pad[:, :, :-1].reshape(B, -1), mask_pad[:, :, -1]], -1)
        mask_pad_shift = mask_pad_shift.reshape(B, T, -1)[:, :, :kernel_size]
        mask_pad_shift = mask_pad_shift.reshape(-1, 1, kernel_size).float()  # [B*T, 1, K]

        # Unfold the input into one window of length K per time step, zero out
        # masked positions in each window, and convolve every window
        # independently; each window yields a single output frame.
        x_pad = F.pad(x, [0, 0, kernel_size // 2, kernel_size // 2], value=0)  # [B, T+K-1, H]
        x_unfold = x_pad.unfold(1, kernel_size, 1)  # [B, T, H, K]
        x_unfold = x_unfold.reshape(-1, H, kernel_size)  # [B*T, H, K]
        x_conv = self.conv(x_unfold * mask_pad_shift)  # [B*T, H', 1]
        x_conv = x_conv.reshape(B, T, self.out_channels)  # [B, T, H']
        x_conv = x_conv.permute(0, 2, 1)  # [B, H', T]
        return x_conv
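

# A minimal usage sketch (an assumed example, not part of the original module):
# B, H, T, out_channels, and the segment layout below are illustrative choices;
# the mask reproduces the example shown in the forward() docstring.
if __name__ == "__main__":
    B, H, T = 2, 16, 8
    conv = Conv1dWithMask(in_channels=H, out_channels=32, kernel_size=3)

    x = torch.randn(B, H, T)  # [B, H, T]

    mask = torch.zeros(B, T, T)
    mask[:, 0:2, 0:2] = 1.0  # positions 0-1 see columns 0-1
    mask[:, 2:4, 0:4] = 1.0  # positions 2-3 see columns 0-3
    mask[:, 4:8, 0:8] = 1.0  # positions 4-7 see all columns

    y = conv(x, mask)
    print(y.shape)  # torch.Size([2, 32, 8]) -> [B, H', T]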