# encoder-decoder self attention mask
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from modules.transformer_tts import TransformerEncoder, TransformerDecoder
from modules.operations import *
from modules.tts_modules import FastspeechDecoder, RefEncoder, DurationPredictor, LengthRegulator, PitchPredictor, \
    TacotronDecoder, EnergyPredictor
from modules.pos_embed import RelativePositionBias
from utils.world_utils import f0_to_coarse_torch, restore_pitch
from utils.tts_utils import sequence_mask
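
# NOTE: `Linear` and the global `hparams` configuration dict are assumed to be
# provided by the wildcard import from `modules.operations`; they are not
# defined in this file.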

class Fastspeech(nn.Module):
    def __init__(self, arch, dictionary):
        super().__init__()
        self.dictionary = dictionary
        self.padding_idx = dictionary.pad()
        if isinstance(arch, str):
            self.arch = list(map(int, arch.strip().split()))
        else:
            assert isinstance(arch, (list, tuple))
            self.arch = arch
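        # the parsed arch holds one integer per transformer layer: the first
        # `enc_layers` entries configure the encoder stack and the next
        # `dec_layers` entries configure the decoder stack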
        self.enc_layers = hparams['enc_layers']
        self.dec_layers = hparams['dec_layers']
        self.enc_arch = self.arch[:self.enc_layers]
        self.dec_arch = self.arch[self.enc_layers:self.enc_layers + self.dec_layers]
        self.hidden_size = hparams['hidden_size']
        self.encoder_embed_tokens = nn.Embedding(len(self.dictionary), self.hidden_size, self.padding_idx)
        self.encoder = TransformerEncoder(self.enc_arch, self.encoder_embed_tokens, last_ln=not hparams['cond_ln'],
                                          padding_type=hparams['enc_ffn_padding'])
        self.decoder = FastspeechDecoder(self.dec_arch, padding_type=hparams['dec_ffn_padding'])
        self.mel_out = Linear(self.hidden_size, hparams['audio_num_mel_bins'], bias=True)
        self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
        if hparams['use_rel_pos_embed']:
            self.enc_rel_pos_embed = RelativePositionBias(**hparams['enc_rel_pos_embed_params'])
            self.dec_rel_pos_embed = RelativePositionBias(**hparams['dec_rel_pos_embed_params'])
        if hparams['cond_ln']:
            self.cond_ln_w_proj = nn.Linear(self.hidden_size, self.hidden_size)
            self.cond_ln_b_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.predict_dur = hparams['predict_dur']
        if hparams['predict_dur']:
            self.dur_predictor = DurationPredictor(self.hidden_size, n_chans=self.hidden_size,
                                                   dropout_rate=hparams['dur_drop'], padding=hparams['enc_ffn_padding'],
                                                   kernel_size=hparams['dur_kernel_size'])
        self.length_regulator = LengthRegulator()
        if hparams['use_pitch_embed']:
            if hparams['pitch_embed_type'] == 'discrete':
                self.pitch_embed = nn.Embedding(300, self.hidden_size, self.padding_idx)
            else:
                self.pitch_embed = nn.Conv1d(1, self.hidden_size, 9, padding=4)
            self.pitch_do = nn.Dropout(hparams['pitch_dropout'])
            if hparams['predict_pitch']:
                self.pitch_predictor = PitchPredictor(
                    self.hidden_size, n_chans=self.hidden_size, dropout_rate=hparams['pitch_drop'],
                    n_layers=hparams['pitch_layer_num'], padding=hparams['dec_ffn_padding'],
                    kernel_size=hparams['pitch_kernel_size'])
        if hparams['use_energy_embed']:
            self.energy_predictor = EnergyPredictor(
                self.hidden_size, n_chans=self.hidden_size, dropout_rate=0.5, odim=1,
                padding=hparams['dec_ffn_padding'], kernel_size=hparams['energy_kernel_size'])
            if hparams['energy_embed_type'] == 'discrete':
                self.energy_embed = nn.Embedding(256, self.hidden_size, self.padding_idx)
            else:
                self.energy_embed = nn.Conv1d(1, self.hidden_size, 9, padding=4)
            self.energy_do = nn.Dropout(0.5)

    def forward(self, src_tokens, mel2ph, spk_embed=None,
                ref_mels=None, pitch=None, uv=None, energy=None, prev_output_mels=None,
                word_bdr=None, word_bdr_mel=None, word_split_point=None):
        """
        :param src_tokens: [B, T]
        :param mel2ph: mel-frame-to-phoneme alignment; when mel2ph is None, the model runs in inference mode
        :param spk_embed:
        :param ref_mels:
        :param word_bdr: word boundary index shared by all tokens of the same word (0 for padding), e.g.
            [ 3,  3,  3,  8,  8,  8,  8,  8, 12, 12, 12, 12,
             21, 21, 21, 21, 21, 21, 21, 21, 21,  0,  0,  0,  0]
        :param word_split_point: word split points, e.g. [0, 3, 8, 12, 21, -10]; -10 is the padding value
        :return: {
            'mel_out': [B, T_s, 80], 'dur': [B, T_t],
            'w_st_pred': [heads, B, tokens], 'w_st': [heads, B, tokens],
            'encoder_out_noref': [B, T_t, H]
        }
        """
        ret = {}
        bsz, src_T = src_tokens.size()
        src_attn_mask_float = None
        src_vis_mask = None
        infer_mode = mel2ph is None
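        # Encoder word-level attention mask: sequence_mask(word_bdr, T) marks, for
        # each token, the positions before its word boundary as visible. With
        # word_enc_mask the complement becomes an additive -inf mask so tokens
        # cannot attend beyond their own word (streaming-friendly); otherwise it is
        # only used as a visibility mask for the masked FFN padding.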
        if hparams['enc_ffn_padding'] == 'MASK' or (hparams["streaming_mode"] and hparams["word_enc_mask"]):
            assert word_bdr is not None
            src_attn_mask = sequence_mask(word_bdr, word_bdr.shape[1])  # [B, T, T]
            src_vis_mask = src_attn_mask
            if hparams['word_enc_mask']:
                src_attn_mask = ~(src_attn_mask.bool())
                # restrict masking to the upper triangle so that no row (e.g. trailing padding rows) is fully masked
                triu = torch.triu(src_tokens.new_ones(src_T, src_T)).bool()
                src_attn_mask &= triu.unsqueeze(0)
                src_attn_mask_float = src_tokens.new_zeros(bsz, src_T, src_T).float() \
                    .masked_fill_(src_attn_mask, float('-inf'))
                src_attn_mask_float = src_attn_mask_float.repeat_interleave(2, dim=0)  # (N*num_heads, L, S)
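        # The relative-position bias is additive, so it is folded into the same
        # tensor as the -inf attention mask: each (batch * head) slice carries
        # both the positional bias and the masked positions.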
        if hparams['use_rel_pos_embed']:
            enc_rel_pos_bias = self.enc_rel_pos_embed(src_T, src_T).repeat(bsz, 1, 1, 1)  # (bsz, num_heads, qlen, klen)
            enc_rel_pos_bias = enc_rel_pos_bias.reshape(-1, src_T, src_T)  # (bsz * num_heads, qlen, klen)
            if src_attn_mask_float is None:
                src_attn_mask_float = enc_rel_pos_bias
            else:
                src_attn_mask_float += enc_rel_pos_bias
        encoder_outputs = self.encoder(src_tokens, attn_mask=src_attn_mask_float, vis_mask=src_vis_mask)
        encoder_out = encoder_outputs['encoder_out']  # [T, B, C]
        src_nonpadding = (src_tokens > 0).float().permute(1, 0)[:, :, None]
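        # Speaker conditioning: either add the projected speaker embedding directly,
        # or (cond_ln) let it predict the scale and bias of a layer norm applied to
        # the encoder output (conditional layer normalization).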
        if hparams['use_spk_embed'] and spk_embed is not None:
            spk_embed = self.spk_embed_proj(spk_embed)[None, :, :]
            if not hparams['cond_ln']:
                encoder_out += spk_embed
            else:
                mean = encoder_out.mean(-1, keepdim=True)
                variance = ((encoder_out - mean) ** 2).mean(-1, keepdim=True)
                norm_x = (encoder_out - mean) * torch.rsqrt(variance + 1e-6)
                encoder_out = norm_x * self.cond_ln_w_proj(spk_embed) + self.cond_ln_b_proj(spk_embed)
        encoder_out = encoder_out * src_nonpadding  # [T, B, C]
        if mel2ph is None:
            dur = self.dur_predictor.inference(encoder_out.transpose(0, 1), src_tokens == 0, vis_mask=src_vis_mask)
            mel2ph = self.length_regulator(dur, (src_tokens != 0).sum(-1))[..., 0]
            ret['pred_dur'] = dur
            ret['mel2ph'] = mel2ph
        elif self.predict_dur:
            ret['dur'] = self.dur_predictor(encoder_out.transpose(0, 1), src_tokens == 0, vis_mask=src_vis_mask)
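        # Energy conditioning comes in two variants: with mel_energy disabled the
        # energy is predicted from the token-level encoder output here (before the
        # length regulator); with mel_energy enabled it is predicted from the
        # expanded, frame-level decoder input further below.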
        # add energy embed
        if hparams['use_energy_embed'] and not hparams['mel_energy']:
            ret['energy_pred'] = energy_pred = torch.clamp(self.energy_predictor(encoder_out.transpose(0, 1),
                                                                                 vis_mask=src_vis_mask)[:, :, 0], min=0.0)
            if energy is None:
                energy = energy_pred * hparams['energy_factor']
                ret['energy_pred'] = energy
            energy = energy.transpose(0, 1)  # [B, T] -> [T, B]
            if hparams['energy_embed_type'] == 'discrete':
                energy = torch.clamp(energy * 256 // 4, max=255).long()
                energy_embed = self.energy_embed(energy)  # [T, B, H]
            else:
                energy_embed = self.energy_embed(energy[:, None, :].transpose(0, 2)).permute(2, 0, 1)
            encoder_out = encoder_out + self.energy_do(energy_embed)
            encoder_out = encoder_out * src_nonpadding  # [T, B, C]
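        # Length regulation: each mel frame copies the encoder state of the phoneme
        # it is aligned to. mel2ph holds 1-based phoneme indices, so a zero row is
        # prepended along the time axis (F.pad) and index 0 selects it for padding frames.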
        # expand encoder out to make decoder inputs
        decoder_inp = F.pad(encoder_out, [0, 0, 0, 0, 1, 0])
        mel2ph_ = mel2ph.permute([1, 0])[..., None].repeat([1, 1, encoder_out.shape[-1]]).contiguous()
        decoder_inp = torch.gather(decoder_inp, 0, mel2ph_).transpose(0, 1)  # [B, T, H]
        decoder_inp_origin = decoder_inp  # [B, T, H]
        tgt_T = decoder_inp.size(1)
        tgt_attn_mask_float = None
        tgt_vis_mask = None
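        # Decoder word-level mask. During training the frame-level word boundaries
        # (word_bdr_mel) come from the ground-truth alignment; at inference they are
        # reconstructed from word_split_point and the predicted mel2ph, one block of
        # frames per word.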
        if hparams['dec_ffn_padding'] == 'MASK' or (hparams["streaming_mode"] and hparams["word_dec_mask"]):
            assert word_bdr is not None
            if not infer_mode:
                assert word_bdr_mel is not None
                tgt_attn_mask = sequence_mask(word_bdr_mel, word_bdr_mel.shape[1])  # [B, T, T]
            else:
                # inference (bsz may be larger than 1)
                assert word_split_point is not None
                tgt_attn_mask = decoder_inp.new_zeros(bsz, tgt_T, tgt_T)
                for bi in range(bsz):
                    last_bdr_mel = 0
                    for wi in range(1, word_split_point.size(1)):
                        if word_split_point[bi, wi] == -10:
                            break
                        # stay compatible with the predicted durations at inference time
                        bdr_mel = torch.where(mel2ph[bi] > word_split_point[bi, wi])[0]
                        if len(bdr_mel) > 0:
                            bdr_mel = bdr_mel[0]
                        else:
                            bdr_mel = torch.where(mel2ph[bi] > 0)[0][-1] + 1
                        tgt_attn_mask[bi, last_bdr_mel:bdr_mel, 0:bdr_mel] = 1
                        last_bdr_mel = bdr_mel
            tgt_vis_mask = tgt_attn_mask
            if hparams['word_dec_mask']:
                tgt_attn_mask = ~(tgt_attn_mask.bool())
                # restrict masking to the upper triangle so that no row ends up fully masked
                triu = torch.triu(decoder_inp.new_ones(tgt_T, tgt_T)).bool()
                tgt_attn_mask &= triu.unsqueeze(0)
                tgt_attn_mask_float = decoder_inp.new_zeros(bsz, tgt_T, tgt_T).float() \
                    .masked_fill_(tgt_attn_mask, float('-inf'))
                tgt_attn_mask_float = tgt_attn_mask_float.repeat_interleave(2, dim=0)  # (N*num_heads, L, S)
        if hparams['use_rel_pos_embed']:
            dec_rel_pos_bias = self.dec_rel_pos_embed(tgt_T, tgt_T).repeat(bsz, 1, 1, 1)  # (bsz, num_heads, qlen, klen)
            dec_rel_pos_bias = dec_rel_pos_bias.reshape(-1, tgt_T, tgt_T)  # (bsz * num_heads, qlen, klen)
            if tgt_attn_mask_float is None:
                tgt_attn_mask_float = dec_rel_pos_bias
            else:
                tgt_attn_mask_float += dec_rel_pos_bias
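        # Pitch conditioning: predicted (or ground-truth) F0 is either quantized to
        # coarse bins and looked up in an embedding table, or log-compressed
        # (log(1 + f0 / 700), a mel-like scale) and passed through a 1-D convolution.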
        # add pitch embed
        if hparams['use_pitch_embed']:
            change_pitch_flag = False
            if hparams['predict_pitch']:
                ret['pitch_logits'] = pitch_logits = self.pitch_predictor(decoder_inp_origin, vis_mask=tgt_vis_mask)
                if pitch is None:
                    change_pitch_flag = True
                    pitch = pitch_logits[:, :, 0]
                    uv = pitch_logits[:, :, 1] > 0
                    pitch_padding = (mel2ph == 0)
                else:
                    pitch_padding = pitch == -200
                pitch = restore_pitch(pitch, uv, hparams, pitch_padding=pitch_padding, change_pitch_flag=change_pitch_flag)
                ret['pitch'] = pitch
            if hparams['pitch_embed_type'] == 'discrete':
                pitch = f0_to_coarse_torch(pitch)
                ret['pitch_coarse'] = pitch
                pitch_embed = self.pitch_embed(pitch)
            else:
                pitch[pitch == 1] = 0
                pitch = (1 + pitch / 700).log()
                pitch = pitch[:, None, :]  # [B, 1, T]
                pitch_embed = self.pitch_embed(pitch).transpose(1, 2)
            decoder_inp = decoder_inp + self.pitch_do(pitch_embed)
        # add energy embed
        if hparams['use_energy_embed'] and hparams['mel_energy']:
            ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp_origin, vis_mask=tgt_vis_mask)[:, :, 0]
            if energy is None:
                energy = energy_pred  # [B, T]
                if hparams['change_energy']:
                    energy_factor = torch.arange(energy.size(1), device=energy.device)[None, :].repeat(energy.size(0), 1)
                    energy_factor = hparams['energy_amp'] * \
                        torch.sin(energy_factor * (2 * np.pi / hparams['energy_period'])) + 1
                    energy *= energy_factor
                ret['energy_pred'] = energy
            if hparams['energy_embed_type'] == 'discrete':
                energy = torch.clamp(energy * 256 // 4, max=255).long()
                energy_embed = self.energy_embed(energy)
            else:
                energy_embed = self.energy_embed(energy[:, None, :]).transpose(1, 2)
            decoder_inp = decoder_inp + self.energy_do(energy_embed)
        decoder_inp = decoder_inp * (mel2ph != 0).float()[:, :, None]
        x = self.decoder(decoder_inp, attn_mask=tgt_attn_mask_float, vis_mask=tgt_vis_mask)
        x = self.mel_out(x)
        x = x * (mel2ph != 0).float()[:, :, None]
        ret['mel_out'] = x
        return ret
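

# Minimal usage sketch (illustrative only; it assumes `hparams` has been filled
# with the keys referenced above, that `dictionary` provides `pad()` and
# `__len__` like a fairseq-style Dictionary, and that `phone_dict`, `phone_ids`,
# `spk_vec`, `word_bdr`, `word_split_point` and the arch string are placeholders):
#
#   model = Fastspeech(arch="8 8 8 8 8 8 8 8", dictionary=phone_dict)
#   model.eval()
#   with torch.no_grad():
#       out = model(src_tokens=phone_ids,   # [B, T_t] LongTensor, 0 = padding
#                   mel2ph=None,            # None -> inference: durations are predicted
#                   spk_embed=spk_vec,      # [B, 256] speaker embedding (optional)
#                   word_bdr=word_bdr,      # required when word-level masking is enabled
#                   word_split_point=word_split_point)
#   mel = out['mel_out']                    # [B, T_s, hparams['audio_num_mel_bins']]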