"""
使用Modelscope开源美式英文TTS生成语音数据,仅供学习分享交流之用,如有侵权,请联系我删除
Use Modelscope's OpenSource TTS to generate English(en-us) speech data, for learning and sharing only, if there is any infringement, please contact me to delete it

下载地址 Download link
https://pan.baidu.com/s/1qUjBhCVknOTV-xm4VBEuDQ?pwd=uqbd 

数据示例data example: 
annie|annie_LJ001-0002_0.wav|IH0 N #1 B IY1 IH0 NG #1 K AH0 M P EH1 R AH0 T IH0 V L IY0 #1 M AA1 D ER0 N #4|in being comparatively modern.

约40小时,16khz单声道,29912句,一男一女,美式英语
About 40 hours, 16 kHz mono, 29912 sentences, one male and one female speaker, American English

优点是带了精确的CMU标注,缺点是阿里开源的TTS生成的语音音质不太好,有杂音
The advantage is that it has accurate CMU annotations, but the sound quality of the TTS generated by Ali open source is not very good, with noise

来源Source: 
- text: ljspeech
- https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_andy_en-us_16k/summary
- https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_annie_en-us_16k/summary

Usage: 
pip install "modelscope[audio]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
pip install --find-links https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html ttsfrd
CUDA_VISIBLE_DEVICES=0 python local/generate_ali_en_tts_data.py --outdir ${your_dir} --file LJSpeech-1.1-metadata.csv --spk annie
CUDA_VISIBLE_DEVICES=0 python local/generate_ali_en_tts_data.py --outdir ${your_dir} --file LJSpeech-1.1-metadata.csv --spk andy
"""

import re
import json
from pathlib import Path
from scipy.io.wavfile import write
from tqdm import tqdm

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# from modelscope.models.audio.tts.sambert_hifi import SambertHifigan
# tts_engine = SambertHifigan(model_dir, am='xx', vocoder = 'xx', lang_type='xx')


def format_ttsfrd_output(res, text):
    """Flatten ttsfrd linguistic output into a space-joined phoneme string.

    Args:
        res: multi-line string from ``gen_tacotron_symbols``; each non-empty
            line is ``"<idx>\\t<token> <token> ..."`` where a token looks like
            ``'{h_c$tone3$s_begin$word_middle$emotion_neutral$F7}'``.
        text: the original sentence, used only in diagnostic print messages.

    Returns:
        All phonemes joined by single spaces; vowels get their tone/stress
        digit appended, break markers (``#1``..``#4``) pass through unchanged.
    """
    phones = []
    for line in res.split("\n"):
        if not line:
            continue
        _, token_str = line.split("\t")
        for token in token_str.split():
            # Strip the surrounding braces, then split the $-separated fields.
            fields = token[1:-1].split("$")
            phone, tone = fields[0], fields[1].replace("tone", "")

            if phone.startswith("#"):
                # Prosodic break marker: keep as-is, no tone handling.
                pass
            elif "_c" in phone:
                # Chinese pinyin (the "_c" suffix distinguishes it from
                # English cmudict symbols); vowels carry tone 0-5.
                if re.search(r"[aeiouvr]", phone):
                    if not re.match(r"[0-5]", tone):
                        print(f"非法中文声调 {text} error: {phone} {tone}")
                    phone += tone
            elif re.search(r"g[a-z]", phone):
                # g-prefixed pseudo-consonants (e.g. ge/go/ga emitted for
                # standalone syllables like 嗯) never take a tone digit.
                pass
            elif re.search(r"[a-z]", phone):
                # English phoneme: only vowels carry a stress digit 0-2.
                if re.search(r"[aeiou]", phone):
                    if not re.match(r"[0-2]", tone):
                        print(f"非法英文声调 {text} error: {phone} {tone}")
                    phone += tone
            else:
                print(f"{text} 含有未知音素 error: {token}, {phone}, {tone}")
            phones.append(phone)
    return " ".join(phones)


# CMU (cmudict) phoneme inventory used by the final sanity check in main():
# vowels carry a stress digit 0/1/2, consonants carry none.
# NOTE(review): bare "UW" (no stress digit, unlike every other vowel here)
# is not standard CMUdict — presumably it covers a ttsfrd token that omits
# the digit; confirm against actual frontend output.
CMU_PHONEMES = [
    "AA0",
    "AA1",
    "AA2",
    "AE0",
    "AE1",
    "AE2",
    "AH0",
    "AH1",
    "AH2",
    "AO0",
    "AO1",
    "AO2",
    "AW0",
    "AW1",
    "AW2",
    "AY0",
    "AY1",
    "AY2",
    "B",
    "CH",
    "D",
    "DH",
    "EH0",
    "EH1",
    "EH2",
    "ER0",
    "ER1",
    "ER2",
    "EY0",
    "EY1",
    "EY2",
    "F",
    "G",
    "HH",
    "IH0",
    "IH1",
    "IH2",
    "IY0",
    "IY1",
    "IY2",
    "JH",
    "K",
    "L",
    "M",
    "N",
    "NG",
    "OW0",
    "OW1",
    "OW2",
    "OY0",
    "OY1",
    "OY2",
    "P",
    "R",
    "S",
    "SH",
    "T",
    "TH",
    "UH0",
    "UH1",
    "UH2",
    "UW",
    "UW0",
    "UW1",
    "UW2",
    "V",
    "W",
    "Y",
    "Z",
    "ZH",
]
# Set view of the list above for O(1) membership tests.
SET_CMU_PHONEMES = set(CMU_PHONEMES)

# Speaker name -> ModelScope model id for the two en-us 16 kHz voices.
spk_model_id = {
    "annie": "damo/speech_sambert-hifigan_tts_annie_en-us_16k",
    "andy": "damo/speech_sambert-hifigan_tts_andy_en-us_16k",
}


def clean_text(text: str):
    """Remove periods that mark abbreviations rather than sentence pauses.

    Replacement order matches the original chain; all substitutions are
    plain ``str.replace`` and independent of one another.
    """
    # Title/rank abbreviations: "Xx. " becomes "Xx " (period dropped).
    abbreviations = (
        "Mr", "Mrs", "Ms", "Dr", "Prof", "St", "Sr", "Jr", "Maj",
        "Gen", "Col", "Lt", "Capt", "Hon", "Sen", "Rep", "Gov",
    )
    text = text.replace("i.e. ", "IE ")
    for abbr in abbreviations:
        text = text.replace(f"{abbr}. ", f"{abbr} ")
    text = text.replace("U.S.", "US")
    return text


def split_by_punc(text: str, puncs: str = ".!?;"):
    """Split *text* after each major punctuation mark, keeping the mark.

    A trailing punctuation mark yields a trailing empty segment, e.g.
    ``"End."`` -> ``["End.", ""]``.
    """
    sentinel = "▁"
    # Map every punctuation char to itself plus a sentinel, then split on it.
    marked = text.translate({ord(p): p + sentinel for p in puncs})
    return marked.split(sentinel)


def main(args):
    """Synthesize every metadata line with the chosen speaker's TTS voice.

    Reads a pipe-separated LJSpeech-style metadata file (``args.file``),
    splits each normalized sentence on major punctuation, synthesizes each
    piece to a 16 kHz mono wav under ``args.outdir/wavs_16k``, and writes a
    ``spk|wav_path|phonemes|text`` csv plus a JSON dump of the raw
    linguistic info. Finally sanity-checks that every emitted phoneme
    (break markers ``#N`` aside) is in the CMU inventory.
    """
    wav_dir = args.outdir / "wavs_16k"
    wav_dir.mkdir(parents=True, exist_ok=True)

    model_id = spk_model_id[args.spk]
    spk = args.spk

    # Collect all (wav_path, sub_text) pairs first so synthesis is one flat loop.
    total_lines = []
    with open(args.file, "r", encoding="utf-8") as f:
        for line in f:
            # metadata format: id|raw_text|normalized_text — keep the last field.
            # BUG FIX: original unpacked into `wav_id, ori_text, ori_text`
            # (duplicate target); the last field still won, so the selected
            # value is unchanged, but the intent is now explicit.
            wav_id, _, ori_text = line.strip().split("|")

            # Split on major punctuation because this tts_engine inserts no
            # pause when concatenating the pieces of a long sentence.
            sub_texts = split_by_punc(clean_text(ori_text))
            sub_texts = [x.strip() for x in sub_texts if x.strip()]

            for j, sub_text in enumerate(sub_texts):
                wav_path = str(wav_dir / f"{spk}_{wav_id}_{j}.wav")
                total_lines.append([wav_path, sub_text])

    tts_engine = pipeline(task=Tasks.text_to_speech, model=model_id)
    # Reach into the pipeline for the ttsfrd frontend via its name-mangled
    # private attribute — fragile, but there is no public accessor.
    frontend = tts_engine.model._SambertHifigan__frontend

    ling_info = {}
    data = []

    with open(args.outdir / f"{spk}_ali_en_us_16k.csv", "w", encoding="utf8") as wf:
        for wav_path, text in tqdm(total_lines):
            result = tts_engine(input=text)
            wav = result[OutputKeys.OUTPUT_PCM]
            write(wav_path, 16000, wav)

            # Keep the raw linguistic info so phoneme extraction can be
            # re-run offline without re-synthesizing.
            linguistic_info = frontend.gen_tacotron_symbols(text)
            ling_info[wav_path] = {"info": linguistic_info, "text": text}

            phonemes = format_ttsfrd_output(linguistic_info, text).upper()

            item = [spk, wav_path, phonemes, text]
            data.append(item)
            wf.write("|".join(item) + "\n")

    with open(args.outdir / f"{spk}_ling_info.json", "w", encoding="utf-8") as wf:
        json.dump(ling_info, wf, indent=4, ensure_ascii=False)

    # Sanity check: every phoneme except #N break markers must be CMU.
    # BUG FIX: original unpacked the 4-element items into 6 targets
    # (`spk, _, _, wav_path, phonemes, text = line`), which raised
    # ValueError on the first iteration, so the check never ran.
    for _spk, _wav_path, phonemes, text in data:
        for phone in phonemes.split():
            if phone.startswith("#"):
                continue
            if phone not in SET_CMU_PHONEMES:
                print(f"{phone} not in CMU_PHONEMES: {text=}")
                break


if __name__ == "__main__":
    import argparse

    # CLI: --file metadata csv, --outdir output root, --spk speaker key
    # (must be one of the keys in spk_model_id).
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", default=None)
    parser.add_argument("--outdir", type=Path, default="./temp")
    parser.add_argument("--spk", type=str, default="annie")

    main(parser.parse_args())