import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv1dWithMask(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, w_init_gain='linear'):
        super().__init__()
        # Even kernel sizes break the window/mask alignment below (the number
        # of unfolded windows no longer matches the number of mask rows).
        assert kernel_size > 1 and kernel_size % 2 == 1, \
            "Conv1dWithMask kernel_size must be an odd number greater than 1"
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, bias=bias)
        nn.init.xavier_uniform_(
            self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))

    def forward(self, x, mask=None):
        """
        :param x: [B, H, T]
        :param mask: [B, T, T], e.g.:
                 tensor([[[1., 1., 0., 0., 0., 0., 0., 0.],
                         [1., 1., 0., 0., 0., 0., 0., 0.],
                         [1., 1., 1., 1., 0., 0., 0., 0.],
                         [1., 1., 1., 1., 0., 0., 0., 0.],
                         [1., 1., 1., 1., 1., 1., 1., 1.],
                         [1., 1., 1., 1., 1., 1., 1., 1.],
                         [1., 1., 1., 1., 1., 1., 1., 1.],
                         [1., 1., 1., 1., 1., 1., 1., 1.]], ...])
        :return: [B, H', T]
        """
        if isinstance(x, (list, tuple)):
            # Allow (x, mask) to be packed into a single argument, e.g. when
            # the module is called through a container that forwards one input.
            assert len(x) == 2
            x, mask = x
        assert mask is not None, "Conv1dWithMask requires a visibility mask"
        x = x.permute(0, 2, 1)  # [B, H, T] -> [B, T, H]
        kernel_size = self.kernel_size
        B, T, H = x.shape

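        # Band-extraction trick: pull out, for every output position t, the K
        # mask values covering its receptive field [t - K//2, t + K//2].
        # Each row of the padded mask has width T+K-1; dropping the last
        # column before flattening and re-chunking the buffer back into rows
        # of width T+K-1 shifts row t left by t columns, so its first K
        # columns are exactly the diagonal band mask_pad[:, t, t:t+K].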
        mask_pad = F.pad(mask, [kernel_size // 2, kernel_size // 2])  # [B, T, T+K-1]
        mask_pad_shift = torch.cat([mask_pad[:, :, :-1].reshape(B, -1), mask_pad[:, :, -1]], -1)  # [B, T*(T+K-1)]
        mask_pad_shift = mask_pad_shift.reshape(B, T, -1)[:, :, :kernel_size]  # [B, T, K]
        mask_pad_shift = mask_pad_shift.reshape(-1, 1, kernel_size).float()  # [B*T, 1, K]

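        # Gather the K-wide input window around every position, zero out the
        # entries that its mask row hides, and run each masked window through
        # the convolution as an independent length-K sequence, which yields
        # exactly one output frame per position.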
        x_pad = F.pad(x, [0, 0, kernel_size // 2, kernel_size // 2], value=0)  # [B, T+K-1, H]
        x_unfold = x_pad.unfold(1, kernel_size, 1)  # [B, T, H, K]
        x_unfold = x_unfold.reshape(-1, H, kernel_size)  # [B*T, H, K]
        x_conv = self.conv(x_unfold * mask_pad_shift)  # [B*T, H', 1]
        x_conv = x_conv.reshape(B, T, self.out_channels)  # [B, T, H']
        x_conv = x_conv.permute(0, 2, 1)  # [B, H', T]
        return x_conv
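

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module above): builds a
# cumulative segment mask shaped like the one in the forward() docstring and
# runs it through Conv1dWithMask. The helper name, segment layout, and all
# shapes below are hypothetical choices for demonstration only.
def build_cumulative_segment_mask(segment_lengths):
    """Hypothetical helper: rows of a segment see every column up to the
    end of that segment, reproducing the docstring's staircase pattern."""
    T = sum(segment_lengths)
    mask = torch.zeros(T, T)
    start = 0
    for length in segment_lengths:
        end = start + length
        mask[start:end, :end] = 1.0  # this segment's rows see columns [0, end)
        start = end
    return mask


if __name__ == "__main__":
    B, H, T = 2, 16, 8
    conv = Conv1dWithMask(in_channels=H, out_channels=32, kernel_size=3)
    x = torch.randn(B, H, T)
    # Segment lengths (2, 2, 4) reproduce the docstring's example mask.
    mask = build_cumulative_segment_mask([2, 2, 4]).unsqueeze(0).expand(B, -1, -1)
    y = conv(x, mask)
    print(y.shape)  # torch.Size([2, 32, 8])

    # Sanity check: with an all-ones mask the module should match a plain
    # zero-padded Conv1d that shares the same weights.
    full = torch.ones(B, T, T)
    ref = F.conv1d(F.pad(x, [1, 1]), conv.conv.weight, conv.conv.bias)
    assert torch.allclose(conv(x, full), ref, atol=1e-5)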