encoder-decoder self attention mask
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.transformer_tts import TransformerEncoder, TransformerDecoder
from modules.operations import *  # wildcard import; expected to provide Linear and other shared ops
from modules.tts_modules import FastspeechDecoder, RefEncoder, DurationPredictor, LengthRegulator, PitchPredictor, \
    TacotronDecoder, EnergyPredictor
from modules.pos_embed import RelativePositionBias
from utils.world_utils import f0_to_coarse_torch, restore_pitch
from utils.tts_utils import sequence_mask
from utils.hparams import hparams  # global config dict; module path assumed from the project layout

class Fastspeech(nn.Module):
    def __init__(self, arch, dictionary):
        super().__init__()
        self.dictionary = dictionary
        self.padding_idx = dictionary.pad()
        if isinstance(arch, str):
            self.arch = list(map(int, arch.strip().split()))
        else:
            assert isinstance(arch, (list, tuple))
            self.arch = arch
        self.enc_layers = hparams['enc_layers']
        self.dec_layers = hparams['dec_layers']
        self.enc_arch = self.arch[:self.enc_layers]
        self.dec_arch = self.arch[self.enc_layers:self.enc_layers + self.dec_layers]
        self.hidden_size = hparams['hidden_size']
        self.encoder_embed_tokens = nn.Embedding(len(self.dictionary), self.hidden_size, self.padding_idx)
        self.encoder = TransformerEncoder(self.enc_arch, self.encoder_embed_tokens, last_ln=not hparams['cond_ln'],
                                          padding_type=hparams['enc_ffn_padding'])
        self.decoder = FastspeechDecoder(self.dec_arch, padding_type=hparams['dec_ffn_padding'])
        self.mel_out = Linear(self.hidden_size, hparams['audio_num_mel_bins'], bias=True)
        self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
        if hparams['use_rel_pos_embed']:
            self.enc_rel_pos_embed = RelativePositionBias(**hparams['enc_rel_pos_embed_params'])
            self.dec_rel_pos_embed = RelativePositionBias(**hparams['dec_rel_pos_embed_params'])
        if hparams['cond_ln']:
            self.cond_ln_w_proj = nn.Linear(self.hidden_size, self.hidden_size)
            self.cond_ln_b_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.predict_dur = hparams['predict_dur']
        if hparams['predict_dur']:
            self.dur_predictor = DurationPredictor(self.hidden_size, n_chans=self.hidden_size,
                                                   dropout_rate=hparams['dur_drop'], padding=hparams['enc_ffn_padding'],
                                                   kernel_size=hparams['dur_kernel_size'])
        self.length_regulator = LengthRegulator()
        if hparams['use_pitch_embed']:
            if hparams['pitch_embed_type'] == 'discrete':
                self.pitch_embed = nn.Embedding(300, self.hidden_size, self.padding_idx)
            else:
                self.pitch_embed = nn.Conv1d(1, self.hidden_size, 9, padding=4)
            self.pitch_do = nn.Dropout(hparams['pitch_dropout'])
            if hparams['predict_pitch']:
                self.pitch_predictor = PitchPredictor(
                    self.hidden_size, n_chans=self.hidden_size, dropout_rate=hparams['pitch_drop'],
                    n_layers=hparams['pitch_layer_num'], padding=hparams['dec_ffn_padding'],
                    kernel_size=hparams['pitch_kernel_size'])
        if hparams['use_energy_embed']:
            self.energy_predictor = EnergyPredictor(
                self.hidden_size, n_chans=self.hidden_size, dropout_rate=0.5, odim=1,
                padding=hparams['dec_ffn_padding'], kernel_size=hparams['energy_kernel_size'])
            if hparams['energy_embed_type'] == 'discrete':
                self.energy_embed = nn.Embedding(256, self.hidden_size, self.padding_idx)
            else:
                self.energy_embed = nn.Conv1d(1, self.hidden_size, 9, padding=4)
            self.energy_do = nn.Dropout(0.5)

    def forward(self, src_tokens, mel2ph, spk_embed=None,
                ref_mels=None, pitch=None, uv=None, energy=None, prev_output_mels=None,
                word_bdr=None, word_bdr_mel=None, word_split_point=None):
        """
        :param src_tokens: [B, T_t] phoneme token ids
        :param mel2ph: [B, T_s] mel-frame-to-phoneme alignment (1-based, 0 = padding); when mel2ph is None,
            the model runs in inference mode and predicts durations instead
        :param spk_embed: [B, 256] speaker embedding (optional)
        :param ref_mels: reference mels (unused in this forward)
        :param word_bdr: [B, T_t] token-level word boundaries; each position holds the (exclusive) end index
            of the word containing that token, 0 at padding positions, e.g.
            [ 3,  3,  3,  8,  8,  8,  8,  8, 12, 12, 12, 12,
             21, 21, 21, 21, 21, 21, 21, 21, 21,  0,  0,  0,  0]
        :param word_bdr_mel: [B, T_s] the same word boundaries expanded to mel frames (required for training)
        :param word_split_point: word boundary token indices, e.g. [0, 3, 8, 12, 21, -10]; padding is -10
        :return: {
            'mel_out': [B, T_s, 80], 'dur': [B, T_t],
            'w_st_pred': [heads, B, tokens], 'w_st': [heads, B, tokens],
            'encoder_out_noref': [B, T_t, H]
        }
        """
        ret = {}
        bsz, src_T = src_tokens.size()
        src_attn_mask_float = None
        src_vis_mask = None
        infer_mode = mel2ph is None
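        # Word-level encoder self-attention mask (streaming mode): sequence_mask(word_bdr, T)
        # yields a [B, T, T] visibility matrix in which token i can see positions up to the end
        # of its own word (see the word_bdr example in the docstring). Inverting it gives the
        # positions to mask; intersecting with an upper-triangular matrix keeps padding rows
        # (word_bdr == 0) from being fully masked before converting to the additive -inf form.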
        if hparams['enc_ffn_padding'] == 'MASK' or (hparams["streaming_mode"] and hparams["word_enc_mask"]):
            assert word_bdr is not None
            src_attn_mask = sequence_mask(word_bdr, word_bdr.shape[1])  # [B, T, T]
            src_vis_mask = src_attn_mask
            if hparams['word_enc_mask']:
                src_attn_mask = ~(src_attn_mask.bool())
                # prevent the last few (padding) rows from being entirely masked
                triu = torch.triu(src_tokens.new_ones(src_T, src_T)).bool()
                src_attn_mask &= triu.unsqueeze(0)
                src_attn_mask_float = src_tokens.new_zeros(bsz, src_T, src_T).float()\
                    .masked_fill_(src_attn_mask, float('-inf'))
                src_attn_mask_float = src_attn_mask_float.repeat_interleave(2, dim=0)  # (N*num_heads, L, S), num_heads == 2
        if hparams['use_rel_pos_embed']:
            enc_rel_pos_bias = self.enc_rel_pos_embed(src_T, src_T).repeat(bsz, 1, 1, 1)  # (bsz, num_heads, qlen, klen)
            enc_rel_pos_bias = enc_rel_pos_bias.reshape(-1, src_T, src_T)  # (bsz * num_heads, qlen, klen)
            if src_attn_mask_float is None:
                src_attn_mask_float = enc_rel_pos_bias
            else:
                src_attn_mask_float += enc_rel_pos_bias
        encoder_outputs = self.encoder(src_tokens, attn_mask=src_attn_mask_float, vis_mask=src_vis_mask)
        encoder_out = encoder_outputs['encoder_out']  # [T, B, C]
        src_nonpadding = (src_tokens > 0).float().permute(1, 0)[:, :, None]
        if hparams['use_spk_embed'] and spk_embed is not None:
            spk_embed = self.spk_embed_proj(spk_embed)[None, :, :]
            if not hparams['cond_ln']:
                encoder_out += spk_embed
            else:
                # conditional layer norm: normalize encoder_out, then scale/shift with speaker-dependent projections
                mean = encoder_out.mean(-1, keepdim=True)
                variance = ((encoder_out - mean) ** 2).mean(-1, keepdim=True)
                norm_x = (encoder_out - mean) * torch.rsqrt(variance + 1e-6)
                encoder_out = norm_x * self.cond_ln_w_proj(spk_embed) + self.cond_ln_b_proj(spk_embed)
        encoder_out = encoder_out * src_nonpadding  # [T, B, C]
        if mel2ph is None:
            dur = self.dur_predictor.inference(encoder_out.transpose(0, 1), src_tokens == 0, vis_mask=src_vis_mask)
            mel2ph = self.length_regulator(dur, (src_tokens != 0).sum(-1))[..., 0]
            ret['pred_dur'] = dur
            ret['mel2ph'] = mel2ph
        elif self.predict_dur:
            ret['dur'] = self.dur_predictor(encoder_out.transpose(0, 1), src_tokens == 0, vis_mask=src_vis_mask)

        # add energy embed (token level)
        if hparams['use_energy_embed'] and not hparams['mel_energy']:
            ret['energy_pred'] = energy_pred = torch.clamp(self.energy_predictor(encoder_out.transpose(0, 1),
                                                                                 vis_mask=src_vis_mask)[:, :, 0], min=0.0)
            if energy is None:
                energy = energy_pred * hparams['energy_factor']
                ret['energy_pred'] = energy
            energy = energy.transpose(0, 1)  # [B, T] -> [T, B]
            if hparams['energy_embed_type'] == 'discrete':
                energy = torch.clamp(energy * 256 // 4, max=255).long()
                energy_embed = self.energy_embed(energy)  # [T, B, H]
            else:
                energy_embed = self.energy_embed(energy[:, None, :].transpose(0, 2)).permute(2, 0, 1)
            encoder_out = encoder_out + self.energy_do(energy_embed)
        encoder_out = encoder_out * src_nonpadding  # [T, B, C]
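        # Length regulation via gather: pad encoder_out with an all-zero row at index 0 so that
        # mel frames whose mel2ph entry is 0 (padding) pick up zero vectors, while mel2ph == k
        # (1-based) copies the k-th token's encoder state to that frame.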
        # expand encoder out to make decoder inputs
        decoder_inp = F.pad(encoder_out, [0, 0, 0, 0, 1, 0])
        mel2ph_ = mel2ph.permute([1, 0])[..., None].repeat([1, 1, encoder_out.shape[-1]]).contiguous()
        decoder_inp = torch.gather(decoder_inp, 0, mel2ph_).transpose(0, 1)  # [B, T, H]
        decoder_inp_origin = decoder_inp  # [B, T, H]
        tgt_T = decoder_inp.size(1)

        tgt_attn_mask_float = None
        tgt_vis_mask = None
        if hparams['dec_ffn_padding'] == 'MASK' or (hparams["streaming_mode"] and hparams["word_dec_mask"]):
            assert word_bdr is not None
            if not infer_mode:
                assert word_bdr_mel is not None
                tgt_attn_mask = sequence_mask(word_bdr_mel, word_bdr_mel.shape[1])  # [B, T, T]
            else:
                # inference mode:
                # assert bsz == 1, f"batch size = {bsz}"  # bsz is not necessarily 1
                assert word_split_point is not None
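                # Rebuild the word-level mel mask from the (possibly predicted) alignment: for each
                # word delimited by word_split_point, find the first mel frame whose mel2ph index
                # passes the word boundary, and let all frames of that word attend to every mel
                # frame up to the end of that word.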
                tgt_attn_mask = decoder_inp.new_zeros(bsz, tgt_T, tgt_T)
                for bi in range(bsz):
                    last_bdr_mel = 0
                    for wi in range(1, word_split_point.size(1)):
                        if word_split_point[bi, wi] == -10:
                            break
                        # compatible with predicted durations at inference time
                        # last_bdr_mel = torch.where(mel2ph[bi] > word_split_point[bi, wi-1])[0][0]
                        bdr_mel = torch.where(mel2ph[bi] > word_split_point[bi, wi])[0]
                        if len(bdr_mel) > 0:
                            bdr_mel = bdr_mel[0]
                        else:
                            bdr_mel = torch.where(mel2ph[bi] > 0)[0][-1] + 1
                        tgt_attn_mask[bi, last_bdr_mel:bdr_mel, 0:bdr_mel] = 1
                        last_bdr_mel = bdr_mel
            tgt_vis_mask = tgt_attn_mask
            if hparams['word_dec_mask']:
                tgt_attn_mask = ~(tgt_attn_mask.bool())
                # prevent the last few (padding) rows from being entirely masked
                triu = torch.triu(decoder_inp.new_ones(tgt_T, tgt_T)).bool()
                tgt_attn_mask &= triu.unsqueeze(0)
                tgt_attn_mask_float = decoder_inp.new_zeros(bsz, tgt_T, tgt_T).float() \
                    .masked_fill_(tgt_attn_mask, float('-inf'))
                tgt_attn_mask_float = tgt_attn_mask_float.repeat_interleave(2, dim=0)  # (N*num_heads, L, S), num_heads == 2
        if hparams['use_rel_pos_embed']:
            dec_rel_pos_bias = self.dec_rel_pos_embed(tgt_T, tgt_T).repeat(bsz, 1, 1, 1)  # (bsz, num_heads, qlen, klen)
            dec_rel_pos_bias = dec_rel_pos_bias.reshape(-1, tgt_T, tgt_T)  # (bsz * num_heads, qlen, klen)
            if tgt_attn_mask_float is None:
                tgt_attn_mask_float = dec_rel_pos_bias
            else:
                tgt_attn_mask_float += dec_rel_pos_bias
        # add pitch embed
        if hparams['use_pitch_embed']:
            change_pitch_flag = False
            if hparams['predict_pitch']:
                ret['pitch_logits'] = pitch_logits = self.pitch_predictor(decoder_inp_origin, vis_mask=tgt_vis_mask)
            if pitch is None:
                change_pitch_flag = True
                pitch = pitch_logits[:, :, 0]
                uv = pitch_logits[:, :, 1] > 0
                pitch_padding = (mel2ph == 0)
            else:
                pitch_padding = pitch == -200
            pitch = restore_pitch(pitch, uv, hparams, pitch_padding=pitch_padding, change_pitch_flag=change_pitch_flag)
            ret['pitch'] = pitch
            if hparams['pitch_embed_type'] == 'discrete':
                pitch = f0_to_coarse_torch(pitch)
                ret['pitch_coarse'] = pitch
                pitch_embed = self.pitch_embed(pitch)
            else:
                pitch[pitch == 1] = 0
                pitch = (1 + pitch / 700).log()
                pitch = pitch[:, None, :]  # [B, 1, T]
                pitch_embed = self.pitch_embed(pitch).transpose(1, 2)
            decoder_inp = decoder_inp + self.pitch_do(pitch_embed)
        # add energy embed (mel-frame level)
        if hparams['use_energy_embed'] and hparams['mel_energy']:
            ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp_origin, vis_mask=tgt_vis_mask)[:, :, 0]
            if energy is None:
                energy = energy_pred  # [B, T]
                if hparams['change_energy']:
                    energy_factor = torch.arange(energy.size(1), device=energy.device)[None, :].repeat(energy.size(0), 1)
                    energy_factor = hparams['energy_amp'] * \
                        torch.sin(energy_factor * (2 * np.pi / hparams['energy_period'])) + 1
                    energy *= energy_factor
                ret['energy_pred'] = energy
            if hparams['energy_embed_type'] == 'discrete':
                energy = torch.clamp(energy * 256 // 4, max=255).long()
                energy_embed = self.energy_embed(energy)
            else:
                energy_embed = self.energy_embed(energy[:, None, :]).transpose(1, 2)
            decoder_inp = decoder_inp + self.energy_do(energy_embed)

        decoder_inp = decoder_inp * (mel2ph != 0).float()[:, :, None]
        x = self.decoder(decoder_inp, attn_mask=tgt_attn_mask_float, vis_mask=tgt_vis_mask)
        x = self.mel_out(x)
        x = x * (mel2ph != 0).float()[:, :, None]
        ret['mel_out'] = x
        return ret
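

if __name__ == '__main__':
    # Minimal sketch of the word-level self-attention mask built in Fastspeech.forward above.
    # `_toy_sequence_mask` is a hypothetical stand-in for utils.tts_utils.sequence_mask, assumed
    # (from how it is consumed above) to return a [B, T, T] matrix with
    # mask[b, i, j] = 1 when j < word_bdr[b, i]; swap in the real helper to verify the assumption.
    def _toy_sequence_mask(lengths, maxlen):
        # lengths: [B, T] -> [B, T, maxlen], 1.0 where position j is visible to query i
        ar = torch.arange(maxlen, device=lengths.device)
        return (ar[None, None, :] < lengths[:, :, None]).float()

    # Toy input: two words (tokens 0-2 with boundary 3, tokens 3-4 with boundary 5) plus one padding token.
    word_bdr = torch.tensor([[3, 3, 3, 5, 5, 0]])
    bsz, src_T = word_bdr.shape

    vis_mask = _toy_sequence_mask(word_bdr, src_T)       # 1 = visible (used as vis_mask above)
    attn_mask = ~vis_mask.bool()                         # True = position to be masked
    triu = torch.triu(torch.ones(src_T, src_T)).bool()   # keep padding rows from being fully masked
    attn_mask &= triu.unsqueeze(0)
    attn_mask_float = torch.zeros(bsz, src_T, src_T).masked_fill_(attn_mask, float('-inf'))
    attn_mask_float = attn_mask_float.repeat_interleave(2, dim=0)  # one copy per head (2 heads, as hard-coded above)

    print(vis_mask[0])         # each token sees everything up to the end of its own word
    print(attn_mask_float[0])  # additive mask handed to the attention layers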