Convert NeMo CitriNet to iOS
from torch.quantization import quantize_dynamic
from torch.utils.mobile_optimizer import optimize_for_mobile
from nemo.collections.asr.models import EncDecCTCModelBPE
# from nemo.collections.asr.parts.preprocessing import FilterbankFeatures
# import librosa
import math

import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional as F
from nemo.utils import logging
@torch.jit.script
def normalize_batch(x: torch.Tensor, seq_len: torch.Tensor, normalize_type: str):
    eps = 1e-5
    if normalize_type == "per_feature":
        x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
        x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
        for i in range(x.shape[0]):
            if x[i, :, : seq_len[i]].shape[1] == 1:
                raise ValueError(
                    "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will "
                    "result in torch.std() returning nan"
                )
            x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1)
            x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1)
        # make sure x_std is not zero
        x_std += eps
        return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
    elif normalize_type == "all_features":
        x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
        x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
        for i in range(x.shape[0]):
            x_mean[i] = x[i, :, : seq_len[i].item()].mean()
            x_std[i] = x[i, :, : seq_len[i].item()].std()
        # make sure x_std is not zero
        x_std += eps
        return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1)
    # elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
    #     x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
    #     x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
    #     return (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2)
    else:
        return x
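# Quick illustration of `normalize_batch` (a sketch; the toy shapes below are
# assumptions for illustration, not values used by this script): with
# "per_feature" normalization each mel bin is standardized using only the
# valid frames of each utterance.
#
#   feats = torch.randn(2, 80, 120)     # (batch, n_mels, frames)
#   lens = torch.tensor([120, 90])      # valid frames per utterance
#   out = normalize_batch(feats, lens, "per_feature")
#   # out[1, :, :90] has ~zero mean / unit std per mel bin; frames past
#   # lens[i] are scaled with stats computed from the valid region only.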
@torch.jit.script
def splice_frames(x: torch.Tensor, frame_splicing: int):
    """Stacks frames together across the feature dim.

    input is (batch_size, feature_dim, num_frames)
    output is (batch_size, feature_dim * frame_splicing, num_frames)
    """
    seq = [x]
    for n in range(1, frame_splicing):
        seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
    return torch.cat(seq, dim=1)
class FilterbankFeatures(nn.Module):
    """Featurizer that converts wavs to mel spectrograms.

    See AudioToMelSpectrogramPreprocessor for args.
    """

    def __init__(
        self,
        sample_rate=16000,
        n_window_size=320,
        n_window_stride=160,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        nfilt=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2 ** -24,
        pad_to=16,
        max_duration=16.7,
        frame_splicing=1,
        exact_pad=False,
        pad_value=0,
        mag_power=2.0,
        use_grads=False,
        constant=1e-5,
    ):
        super().__init__()
        if exact_pad and n_window_stride % 2 == 1:
            raise NotImplementedError(
                f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0, the "
                "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
            )
        self.log_zero_guard_value = log_zero_guard_value
        if (
            n_window_size is None
            or n_window_stride is None
            or not isinstance(n_window_size, int)
            or not isinstance(n_window_stride, int)
            or n_window_size <= 0
            or n_window_stride <= 0
        ):
            raise ValueError(
                f"{self} got an invalid value for either n_window_size or "
                f"n_window_stride. Both must be positive ints."
            )
        logging.info(f"PADDING: {pad_to}")

        self.win_length = n_window_size
        self.hop_length = n_window_stride
        self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
        self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
        if exact_pad:
            logging.info("STFT using exact pad")
        torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        window_fn = torch_windows.get(window, None)
        window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
        self.register_buffer("window", window_tensor)

        self.exact_pad = exact_pad
        self._constant = constant
        self.normalize = normalize
        self.log = log
        self.frame_splicing = frame_splicing
        self.nfilt = nfilt
        self.preemph = preemph
        self.pad_to = pad_to
        highfreq = highfreq or sample_rate / 2

        # filterbanks = torch.tensor(
        #     librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq),
        #     dtype=torch.float,
        # ).unsqueeze(0)
        # NOTE: torchaudio's melscale_fbanks defaults to the HTK mel scale with no
        # normalization, whereas librosa.filters.mel (used by NeMo, commented out
        # above) defaults to the Slaney scale with Slaney normalization; pass
        # norm="slaney", mel_scale="slaney" here if exact parity is needed.
        filterbanks = F.melscale_fbanks(
            sample_rate=sample_rate, n_freqs=int(self.n_fft // 2 + 1), n_mels=nfilt, f_min=lowfreq, f_max=highfreq,
        ).T.unsqueeze(0)
        self.register_buffer("fb", filterbanks)

        # Calculate maximum sequence length
        max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
        max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
        self.max_length = max_length + max_pad
        self.pad_value = pad_value
        self.mag_power = mag_power

        # We want to avoid taking the log of zero; there are two options:
        # either adding or clamping to a small value.
        if log_zero_guard_type not in ["add", "clamp"]:
            raise ValueError(
                f"{self} received {log_zero_guard_type} for the "
                f"log_zero_guard_type parameter. It must be either 'add' or "
                f"'clamp'."
            )
        # log_zero_guard_value is the small value we want to use; we support
        # an actual number, or "tiny", or "eps"
        self.log_zero_guard_type = log_zero_guard_type
        logging.debug(f"sr: {sample_rate}")
        logging.debug(f"n_fft: {self.n_fft}")
        logging.debug(f"win_length: {self.win_length}")
        logging.debug(f"hop_length: {self.hop_length}")
        logging.debug(f"n_mels: {nfilt}")
        logging.debug(f"fmin: {lowfreq}")
        logging.debug(f"fmax: {highfreq}")
        logging.debug(f"using grads: {use_grads}")
    def log_zero_guard_value_fn(self, x):
        if isinstance(self.log_zero_guard_value, str):
            if self.log_zero_guard_value == "tiny":
                return torch.finfo(x.dtype).tiny
            elif self.log_zero_guard_value == "eps":
                return torch.finfo(x.dtype).eps
            else:
                raise ValueError(
                    f"{self} received {self.log_zero_guard_value} for the "
                    f"log_zero_guard_value parameter. It must be either a "
                    f"number, 'tiny', or 'eps'"
                )
        else:
            return self.log_zero_guard_value

    def get_seq_len(self, seq_len):
        # When stft_pad_amount is None, assume center=True, i.e. the STFT pads
        # n_fft // 2 samples on each side.
        pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
        seq_len = torch.floor((seq_len + pad_amount - self.n_fft) / self.hop_length) + 1
        return seq_len.to(dtype=torch.long)
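    # Worked example for get_seq_len (values match ModelWrapper's config below):
    # with n_fft=512, hop_length=160 and stft_pad_amount=None (center=True),
    # pad_amount = 2 * (512 // 2) = 512, so one second of 16 kHz audio yields
    # floor((16000 + 512 - 512) / 160) + 1 = 101 frames.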
    @property
    def filter_banks(self):
        return self.fb
    def normalize_batch(self, x: torch.Tensor, seq_len: torch.Tensor, normalize_type: str):
        if normalize_type == "per_feature":
            x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
            x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
            for i in range(x.shape[0]):
                if x[i, :, : seq_len[i]].shape[1] == 1:
                    raise ValueError(
                        "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This "
                        "will result in torch.std() returning nan"
                    )
                x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1)
                x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1)
            # make sure x_std is not zero
            x_std += self._constant
            return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
        elif normalize_type == "all_features":
            x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
            x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
            for i in range(x.shape[0]):
                x_mean[i] = x[i, :, : seq_len[i].item()].mean()
                x_std[i] = x[i, :, : seq_len[i].item()].std()
            # make sure x_std is not zero
            x_std += self._constant
            return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1)
        # elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
        #     x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
        #     x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
        #     return (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2)
        else:
            return x
    def splice_frames(self, x: torch.Tensor, frame_splicing: int):
        """Stacks frames together across the feature dim.

        input is (batch_size, feature_dim, num_frames)
        output is (batch_size, feature_dim * frame_splicing, num_frames)
        """
        seq = [x]
        for n in range(1, frame_splicing):
            seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
        return torch.cat(seq, dim=1)
    def forward(self, x, seq_len):
        seq_len = self.get_seq_len(seq_len.float())

        if self.stft_pad_amount is not None:
            x = torch.nn.functional.pad(
                x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
            ).squeeze(1)

        # do preemphasis
        if self.preemph is not None:
            x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)

        # NeMo disables autocast here to get the full range of stft values:
        #     with torch.cuda.amp.autocast(enabled=False):
        #         x = self.stft(x)
        # MSIS: that autocast context breaks the mobile export, so torch.stft
        # is called directly instead. Note that return_complex=False is
        # deprecated in recent PyTorch; the view_as_real branch below handles
        # the return_complex=True case.
        x = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=False if self.exact_pad else True,
            window=self.window.to(dtype=torch.float),
            return_complex=False,
        )

        # torch returns (real, imag), so convert to magnitude
        # (a small guard would be needed for sqrt if grads were passed through)
        if x.dtype in [torch.cfloat, torch.cdouble]:
            x = torch.view_as_real(x)
        x = torch.sqrt(x.pow(2).sum(-1))

        # get power spectrum
        if self.mag_power != 1.0:
            x = x.pow(self.mag_power)

        # dot with filterbank energies
        x = torch.matmul(self.fb.to(x.dtype), x)

        # log features if required
        if self.log:
            if self.log_zero_guard_type == "add":
                x = torch.log(x + self.log_zero_guard_value_fn(x))
            elif self.log_zero_guard_type == "clamp":
                x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))

        # frame splicing if required
        if self.frame_splicing > 1:
            x = self.splice_frames(x, self.frame_splicing)

        # normalize if required
        if self.normalize:
            x = self.normalize_batch(x, seq_len, normalize_type=self.normalize)

        # mask to zero any values beyond seq_len in batch, pad to multiple of
        # `pad_to` (for efficiency)
        max_len = x.size(-1)
        mask = torch.arange(max_len).to(x.device)
        mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
        x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
        del mask

        # NeMo also supports pad_to == "max" (pad out to self.max_length); only
        # the numeric pad_to path is kept for the mobile export.
        if self.pad_to > 0:
            pad_amt = x.size(-1) % self.pad_to
            if pad_amt != 0:
                x = torch.nn.functional.pad(x, (0, self.pad_to - pad_amt), value=float(self.pad_value))
        return x, seq_len
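# Minimal smoke test for the featurizer (a sketch; the one-second random
# waveform is an assumption used only to show expected shapes):
#
#   fb = FilterbankFeatures(sample_rate=16000, nfilt=80, n_fft=512,
#                           n_window_size=400, n_window_stride=160, pad_to=60)
#   wav = torch.randn(1, 16000)
#   feats, feat_len = fb(wav, torch.tensor([16000.0]))
#   # feat_len -> tensor([101]) per the worked example above; feats has shape
#   # (1, 80, 120) since 101 frames are padded up to a multiple of pad_to=60.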
class ModelWrapper2(torch.nn.Module):
    """Scratch wrapper for testing masked_fill_ in isolation; unused by the export below."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # masked_fill_ expects a boolean mask, so threshold the random tensor.
        mask = torch.randn((1, 80, 100)) > 0
        x.masked_fill_(mask, 0.0)
        return x
class ModelWrapper(torch.nn.Module):
    """Bundles the mel-spectrogram frontend with the exported CitriNet TorchScript model."""

    def __init__(self, exported_model):
        super().__init__()
        self.encoder = exported_model
        self.sample_rate = 16000
        self.featurizer = FilterbankFeatures(
            sample_rate=self.sample_rate,
            nfilt=80,
            n_fft=512,
            pad_to=60,
            normalize='per_feature',
            n_window_size=400,
            n_window_stride=160,
            window='hann',
            frame_splicing=1,
        )

    def forward(self, waveform: torch.Tensor, sample_rate: int):
        # downmix to mono and resample to the model's rate if needed
        if waveform.size(0) != 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sample_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.sample_rate)
        length = torch.tensor([waveform.shape[1]])
        waveform, length = self.featurizer(waveform, length)
        log_probs = self.encoder(waveform, length)
        return log_probs
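# Greedy CTC decoding sketch over the wrapper's output (assumptions: the
# exported NeMo CTC model returns per-frame log-probabilities shaped
# (batch, frames, vocab) with the blank token as the last vocab index, which
# is NeMo's convention; turning ids into text needs the model's tokenizer):
#
#   log_probs = wrapped_model(waveform, sample_rate)
#   ids = log_probs.argmax(dim=-1)[0].tolist()
#   blank_id = log_probs.size(-1) - 1
#   hyp, prev = [], blank_id
#   for i in ids:
#       if i != prev and i != blank_id:
#           hyp.append(i)
#       prev = i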
if __name__ == "__main__":
    # Load the pretrained CitriNet and export it to TorchScript.
    model = EncDecCTCModelBPE.from_pretrained('stt_en_citrinet_256', map_location='cpu')
    model = model.eval()
    model.export(f"/tmp/{model._get_name()}.ts", check_trace=True)
    scripted_encoder = torch.jit.load(f"/tmp/{model._get_name()}.ts")

    # Wrap it with the on-device featurizer frontend and script the result.
    wrapped_model = ModelWrapper(scripted_encoder)
    scripted_model = torch.jit.script(wrapped_model)

    # Dynamically quantize the Linear layers to int8, then optimize for the
    # Metal GPU backend and save for the PyTorch lite interpreter.
    quantized_model = quantize_dynamic(
        scripted_model,
        qconfig_spec={torch.nn.Linear},
        dtype=torch.qint8
    )
    optimized_model = optimize_for_mobile(quantized_model, backend="metal")
    optimized_model._save_for_lite_interpreter("/tmp/mymodel-metal.ts")
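    # Optional sanity check before shipping (a sketch; the Metal-optimized
    # artifact only runs on-device, so test the pre-quantization wrapper here,
    # and note /tmp/test.wav is an illustrative path, not a file this script
    # creates):
    #
    #   waveform, sr = torchaudio.load("/tmp/test.wav")
    #   log_probs = wrapped_model(waveform, sr)
    #   print(log_probs.shape)  # expect per-frame log-probs, (1, frames, vocab)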