MockingBird/vits.ipynb
2023-02-18 09:31:05 +08:00

547 lines
19 KiB
Plaintext
Vendored
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from utils.hparams import load_hparams_json\n",
"from utils.util import intersperse\n",
"import json\n",
"from models.synthesizer.models.vits import Vits\n",
"import torch\n",
"import numpy as np\n",
"import IPython.display as ipd\n",
"from models.synthesizer.utils.symbols import symbols\n",
"from models.synthesizer.utils.text import text_to_sequence\n",
"import os\n",
"import random, re\n",
"from pypinyin import lazy_pinyin, Style\n",
"\n",
"\n",
"hps = load_hparams_json(\"data/ckpt/synthesizer/vits5/config.json\")\n",
"print(hps.train)\n",
"model = Vits(\n",
"    len(symbols),\n",
"    hps[\"data\"][\"filter_length\"] // 2 + 1,\n",
"    hps[\"train\"][\"segment_size\"] // hps[\"data\"][\"hop_length\"],\n",
"    n_speakers=hps[\"data\"][\"n_speakers\"],\n",
"    **hps[\"model\"])\n",
"_ = model.eval()\n",
"device = torch.device(\"cpu\")\n",
"checkpoint = torch.load(\"data/ckpt/synthesizer/vits5/G_56000.pth\", map_location=device)\n",
"# The weights are read from \"model_state\" when that key exists, otherwise from \"model\".\n",
"state = checkpoint[\"model_state\"] if \"model_state\" in checkpoint else checkpoint[\"model\"]\n",
"# strict=False: keys that do not match the model are tolerated instead of raising.\n",
"model.load_state_dict(state, strict=False)\n",
"\n",
"# Root directory of the emotion reference files used by tts()\n",
"random_emotion_root = \"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\emo\\\\\"\n",
"\n",
"\n",
"def tts(txt, emotion, sid=0):\n",
"    \"\"\"Synthesize `txt` with the loaded VITS model and play the audio inline.\n",
"\n",
"    Args:\n",
"        txt: Text to speak; it is converted to pinyin (Style.TONE3) before encoding.\n",
"        emotion: Selects the emotion embedding. A value ending in \"wav\" is a\n",
"            reference audio path whose embedding is extracted on the fly;\n",
"            \"random_sample\" picks a random file from `random_emotion_root`;\n",
"            a value ending in \"npy\" is a file name inside `random_emotion_root`.\n",
"        sid: Speaker id passed to the model.\n",
"\n",
"    Raises:\n",
"        ValueError: If `emotion` matches none of the forms above.\n",
"    \"\"\"\n",
"    txt = \" \".join(lazy_pinyin(txt, style=Style.TONE3, neutral_tone_with_five=False))\n",
"    text_norm = text_to_sequence(txt, hps[\"data\"][\"text_cleaners\"])\n",
"    # if hps[\"data\"][\"add_blank\"]:\n",
"    #     text_norm = intersperse(text_norm, 0)\n",
"    stn_tst = torch.LongTensor(text_norm)\n",
"\n",
"    with torch.no_grad():  # inference only, no gradients needed\n",
"        x_tst = stn_tst.unsqueeze(0)\n",
"        x_tst_lengths = torch.LongTensor([stn_tst.size(0)])\n",
"        sid = torch.LongTensor([sid])\n",
"        if emotion.endswith(\"wav\"):\n",
"            from models.synthesizer.preprocess_audio import extract_emo\n",
"            import librosa\n",
"            # Pass `sr` by keyword: it is keyword-only in librosa >= 0.10.\n",
"            wav, sr = librosa.load(emotion, sr=16000)\n",
"            emo = torch.FloatTensor(extract_emo(np.expand_dims(wav, 0), sr, embeddings=True))\n",
"        elif emotion == \"random_sample\":\n",
"            rand_emo = random.choice(os.listdir(random_emotion_root))\n",
"            print(rand_emo)\n",
"            emo = torch.FloatTensor(np.load(os.path.join(random_emotion_root, rand_emo))).unsqueeze(0)\n",
"        elif emotion.endswith(\"npy\"):\n",
"            print(emotion)\n",
"            emo = torch.FloatTensor(np.load(os.path.join(random_emotion_root, emotion))).unsqueeze(0)\n",
"        else:\n",
"            # Fail here: without an embedding the infer call below would hit an undefined `emo`.\n",
"            raise ValueError(f\"emotion参数不正确: {emotion!r}\")\n",
"\n",
"        audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.8, length_scale=1, emo=emo)[0][0,0].data.float().numpy()\n",
"    ipd.display(ipd.Audio(audio, rate=hps[\"data\"][\"sampling_rate\"], normalize=False))\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"推理:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
"# One emotion reference file per speaking style, synthesized in this order.\n",
"emotion_refs = [\n",
"    'emo-T0055G4906S0052.wav_00.npy',         # neutral (fast alternative: emo-T0055G2323S0179.wav_00.npy)\n",
"    'emo-15_4581_20170825202626.wav_00.npy',  # sad\n",
"    'emo-T0055G2412S0498.wav_00.npy',         # happy (T0055G2412S0498.wav)\n",
"    'emo-T0055G1344S0160.wav_00.npy',         # angry (alternative: T0055G1371S0363.wav)\n",
"    'emo-T0055G2294S0476.wav_00.npy',         # tired\n",
"    'emo-T0055G1671S0170.wav_00.npy',         # anxious\n",
"]\n",
"for emo_ref in emotion_refs:\n",
"    tts(txt, emotion=emo_ref, sid=100)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
"# Six renditions of the same sentence, each with a randomly drawn emotion reference.\n",
"for attempt in range(6):\n",
"    tts(txt, emotion='random_sample', sid=100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
"# Each label names a reference recording \"<label>.wav\" in the folder below.\n",
"types = [\"平淡\", \"激动\", \"疲惫\", \"兴奋\", \"沮丧\", \"开心\"]\n",
"wav_paths = [f'C:\\\\Users\\\\babys\\\\Music\\\\{t}.wav' for t in types]\n",
"for label, wav_path in zip(types, wav_paths):\n",
"    print(label)\n",
"    tts(txt, emotion=wav_path, sid=100)\n",
"# tts(txt, emotion='D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\G1858\\\\T0055G1858S0342.wav', sid=5)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"预处理:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from models.synthesizer.preprocess import preprocess_dataset\n",
"from pathlib import Path\n",
"from utils.hparams import HParams\n",
"\n",
"datasets_root = Path(\"../audiodata/\")\n",
"# Audio / mel settings handed to the preprocessing step.\n",
"hparams = HParams(\n",
"    n_fft=1024,  # filter_length\n",
"    num_mels=80,\n",
"    hop_size=256,  # frame shift in samples (16 ms at sample_rate=16000)\n",
"    win_size=1024,  # frame length in samples (64 ms at sample_rate=16000)\n",
"    fmin=55,\n",
"    min_level_db=-100,\n",
"    ref_level_db=20,\n",
"    max_abs_value=4.,  # Gradient explodes if too big, premature convergence if too small.\n",
"    sample_rate=16000,\n",
"    rescale=True,\n",
"    max_mel_frames=900,\n",
"    rescaling_max=0.9,\n",
"    preemphasis=0.97,  # Filter coefficient to use if preemphasize is True\n",
"    preemphasize=True,\n",
"    # Mel visualization and Griffin-Lim\n",
"    signal_normalization=True,\n",
"\n",
"    utterance_min_duration=1.6,  # Duration in seconds below which utterances are discarded\n",
"    # Audio processing options\n",
"    fmax=7600,  # Should not exceed (sample_rate // 2)\n",
"    allow_clipping_in_normalization=True,  # Used when signal_normalization = True\n",
"    clip_mels_length=True,  # If true, discards samples exceeding max_mel_frames\n",
"    use_lws=False,  # \"Fast spectrogram phase recovery using local weighted sums\"\n",
"    symmetric_mels=True,  # Sets mel range to [-max_abs_value, max_abs_value] if True,\n",
"                          # and [0, max_abs_value] if False\n",
"    trim_silence=False,  # Use with sample_rate of 16000 for best results\n",
")\n",
"# Output goes to <datasets_root>/SV2TTS/synthesizer.\n",
"synthesizer_dir = datasets_root / \"SV2TTS\" / \"synthesizer\"\n",
"preprocess_dataset(\n",
"    datasets_root=datasets_root,\n",
"    out_dir=synthesizer_dir,\n",
"    n_processes=8,\n",
"    skip_existing=True,\n",
"    hparams=hparams,\n",
"    no_alignments=False,\n",
"    dataset=\"aidatatang_200zh\",\n",
"    emotion_extract=True,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"训练:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from models.synthesizer.train_vits import run\n",
"from pathlib import Path\n",
"from utils.hparams import HParams\n",
"import torch, os\n",
"import torch.multiprocessing as mp\n",
"\n",
"datasets_root = Path(\"../audiodata/SV2TTS/synthesizer\")\n",
"hparams = HParams(\n",
"    model_dir=\"data/ckpt/synthesizer/vits\",\n",
")\n",
"# Load the remaining settings from the config.json stored next to the checkpoints.\n",
"hparams.loadJson(Path(hparams.model_dir).joinpath(\"config.json\"))\n",
"hparams.data[\"training_files\"] = str(datasets_root.joinpath(\"train.txt\"))\n",
"# NOTE(review): validation reuses train.txt, so it is measured on training data - confirm this is intended.\n",
"hparams.data[\"validation_files\"] = str(datasets_root.joinpath(\"train.txt\"))\n",
"hparams.data[\"datasets_root\"] = str(datasets_root)\n",
"\n",
"n_gpus = torch.cuda.device_count()\n",
"if n_gpus == 0:\n",
"    # mp.spawn with nprocs=0 starts no worker and returns at once, so fail loudly instead.\n",
"    raise RuntimeError(\"No CUDA device found: mp.spawn would start 0 training workers.\")\n",
"# Address and port for the spawned workers (presumably read by torch.distributed inside `run`).\n",
"os.environ['MASTER_ADDR'] = 'localhost'\n",
"os.environ['MASTER_PORT'] = '8899'\n",
"# One worker process per visible GPU.\n",
"mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hparams))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"挑选只有对应emo文件的meta数据"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import os\n",
"\n",
"# Keep only the rows of train.txt whose emotion file exists, and write them to train2.txt.\n",
"root = Path('../audiodata/SV2TTS/synthesizer')\n",
"emo_dir = root.joinpath(\"emo\")\n",
"kept_rows = []\n",
"with open(root.joinpath(\"train.txt\"), \"r\", encoding=\"utf-8\") as meta_in:\n",
"    for raw in meta_in:\n",
"        # Lines read from a file keep their newline, so test the stripped text to catch blank lines.\n",
"        if not raw.strip():\n",
"            continue\n",
"        # The emotion file name is the first field with \"audio\" replaced by \"emo\".\n",
"        emo_name = raw.split(\"|\")[0].replace(\"audio\", \"emo\")\n",
"        if emo_dir.joinpath(emo_name).exists():\n",
"            kept_rows.append(raw)\n",
"\n",
"# `with` closes the output file even if a write fails.\n",
"with root.joinpath(\"train2.txt\").open(\"w\", encoding=\"utf-8\") as meta_out:\n",
"    meta_out.writelines(kept_rows)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"从训练集中抽取10%作为测试集"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"# Split train.txt: every 10th row goes to eval.txt, the rest to train1.txt.\n",
"root = Path('../audiodata/SV2TTS/synthesizer')\n",
"with open(root.joinpath(\"train.txt\"), \"r\", encoding=\"utf-8\") as meta_in:\n",
"    # Lines keep their newline, so test the stripped text to drop blank lines.\n",
"    all_rows = [raw for raw in meta_in if raw.strip()]\n",
"\n",
"train_rows = []\n",
"eval_rows = []\n",
"# Counting starts at 1, so rows 10, 20, 30, ... form the evaluation set.\n",
"for count, raw in enumerate(all_rows, start=1):\n",
"    if count % 10 == 0:\n",
"        eval_rows.append(raw)\n",
"    else:\n",
"        train_rows.append(raw)\n",
"\n",
"# `with` closes each output file even if a write fails.\n",
"with root.joinpath(\"train1.txt\").open(\"w\", encoding=\"utf-8\") as train_out:\n",
"    train_out.writelines(train_rows)\n",
"\n",
"with root.joinpath(\"eval.txt\").open(\"w\", encoding=\"utf-8\") as eval_out:\n",
"    eval_out.writelines(eval_rows)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"# Append a numeric speaker id to every row of eval.txt and write the result to eval2.txt.\n",
"root = Path('../audiodata/SV2TTS/synthesizer')\n",
"\n",
"\n",
"def speaker_of(row):\n",
"    \"\"\"Return the 4-character speaker code sliced from a metadata row.\"\"\"\n",
"    return row.split(\"-\")[1][6:10]\n",
"\n",
"\n",
"with open(root.joinpath(\"eval.txt\"), \"r\", encoding=\"utf-8\") as meta_in:\n",
"    # Blank lines carry no speaker code and would raise IndexError in speaker_of.\n",
"    rows = [raw for raw in meta_in if raw.strip()]\n",
"\n",
"# Ids are assigned in sorted order of the speaker codes, starting at 0.\n",
"spks = sorted({speaker_of(raw) for raw in rows})\n",
"spk_id = {sp: str(i) for i, sp in enumerate(spks)}\n",
"print(len(spks))\n",
"\n",
"# `with` closes the output file even if a write fails.\n",
"with root.joinpath(\"eval2.txt\").open(\"w\", encoding=\"utf-8\") as meta_out:\n",
"    for row in rows:\n",
"        meta_out.write(row.strip() + \"|\" + spk_id[speaker_of(row)] + \"\\n\")\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"[Not Recommended]\n",
"Try mapping each transcript to a more detailed, phoneme-level format:\n",
"`ni3 hao3 -> n i3 <pad> h ao3`\n",
"\n",
"After a couple of tests, I think this method does not improve the quality of the results and may crash the monotonic alignment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from pathlib import Path\n",
"datasets_root = Path(\"../audiodata/SV2TTS/synthesizer/\")\n",
"\n",
"# Dictionary file: one \"<key>\\t<replacement>\" entry per line.\n",
"dictionary_fp = Path(\"../audiodata/ProDiff/processed/mandarin_pinyin.dict\")\n",
"dict_map = {}\n",
"with open(dictionary_fp, encoding='utf-8') as dict_file:\n",
"    for entry in dict_file:\n",
"        item = entry.split(\"\\t\")\n",
"        if len(item) < 2:\n",
"            continue  # blank or malformed line without a tab-separated value\n",
"        dict_map[item[0]] = item[1].replace(\"\\n\", \"\")\n",
"\n",
"# Open the input first so a missing train.txt cannot leave a truncated train2.txt behind.\n",
"with open(datasets_root.joinpath('train.txt'), encoding='utf-8') as src, \\\n",
"        datasets_root.joinpath('train2.txt').open(\"w\", encoding='utf-8') as f:\n",
"    for line in src:\n",
"        if not line.strip():\n",
"            continue  # skip blank lines, which have no field 5\n",
"        items = line.strip().replace(\"\\n\", \"\").replace(\"\\t\", \" \").split(\"|\")\n",
"        # Field 5 is rewritten word by word: mapped through dict_map when known,\n",
"        # kept as is otherwise, and every word is followed by \" _ \".\n",
"        items[5] = \"\".join(dict_map.get(word, word) + \" _ \" for word in items[5].split(\" \"))\n",
"        f.write(\"|\".join(items) + \"\\n\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"预处理后的数据可视化"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import librosa.display\n",
"import librosa, torch\n",
"import numpy as np\n",
"from utils.audio_utils import spectrogram, mel_spectrogram, load_wav_to_torch, spec_to_mel\n",
"\n",
"# Sampling rate used for every plot and for spec_to_mel below, so the time and frequency axes agree.\n",
"# NOTE(review): assumes the preprocessed .npy audio is stored at 16 kHz - confirm against the preprocessing settings.\n",
"SAMPLE_RATE = 16000\n",
"\n",
"# x, sr = librosa.load(\"D:\\audiodata\\SV2TTS\\synthesizer\\audio\\audio-T0055G2333S0196.wav_00.npy\")\n",
"x = np.load(\"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\audio\\\\audio-T0055G1858S0342.wav_00.npy\")\n",
"\n",
"# waveplot was removed in librosa 0.10; prefer waveshow and fall back on older versions.\n",
"plot_wave = getattr(librosa.display, \"waveshow\", None) or librosa.display.waveplot\n",
"plt.figure(figsize=(14, 5))\n",
"plot_wave(x, sr=SAMPLE_RATE)\n",
"\n",
"X = librosa.stft(x)\n",
"Xdb = librosa.amplitude_to_db(abs(X))\n",
"plt.figure(figsize=(14, 5))\n",
"librosa.display.specshow(Xdb, sr=SAMPLE_RATE, x_axis='time', y_axis='hz')\n",
"\n",
"# spectrogram = np.load(\"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\mels\\\\mel-T0055G1858S0342.wav_00.npy\")\n",
"audio = torch.from_numpy(x.astype(np.float32))\n",
"\n",
"# audio, sampling_rate = load_wav_to_torch(\"D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\G1858\\\\T0055G1858S0342.wav\")\n",
"# audio_norm = audio / 32768.0\n",
"audio_norm = audio.unsqueeze(0)  # add a leading dimension before calling spectrogram\n",
"spec = spectrogram(audio_norm, 1024, 256, 1024, center=False)\n",
"# spec = spec_to_mel()\n",
"spec = torch.squeeze(spec, 0)\n",
"mel = spec_to_mel(spec, 1024, 80, SAMPLE_RATE, 0, None)\n",
"\n",
"fig = plt.figure(figsize=(10, 8))\n",
"ax2 = fig.add_subplot(211)\n",
"im = ax2.imshow(mel, interpolation=\"none\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"情感聚类"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# from sklearn import metrics\n",
"# from sklearn.mixture import GaussianMixture  # Gaussian mixture model\n",
"import os\n",
"import numpy as np\n",
"import librosa\n",
"import IPython.display as ipd\n",
"from random import sample\n",
"\n",
"embs = []\n",
"wavnames = []\n",
"emo_root_path = \"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\emo\\\\\"\n",
"wav_root_path = \"D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\\"\n",
"emo_files = os.listdir(emo_root_path)\n",
"# random.sample raises ValueError when asked for more items than exist, so cap the request.\n",
"for emo_fpath in sample(emo_files, min(10000, len(emo_files))):\n",
"    if emo_fpath.endswith(\".npy\") and emo_fpath.startswith(\"emo-T\"):\n",
"        embs.append(np.expand_dims(np.load(os.path.join(emo_root_path, emo_fpath)), axis=0))\n",
"        # e.g. \"emo-T0055G1858S0342.wav_00.npy\" -> folder \"G1858\", file \"T0055G1858S0342.wav\"\n",
"        wav_fpath = os.path.join(wav_root_path, emo_fpath[9:14], emo_fpath.split(\"_00\")[0][4:])\n",
"        wavnames.append(wav_fpath)\n",
"print(len(embs))\n",
"\n",
"\n",
"# Stack the embeddings along axis 0, one row per sampled file.\n",
"x = np.concatenate(embs, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Number of clusters for the clustering algorithm\n",
"n_clusters = 20\n",
"from sklearn.cluster import AgglomerativeClustering, Birch, DBSCAN, KMeans, SpectralClustering\n",
"import random\n",
"\n",
"# Named `cluster_model` so the VITS `model` used by tts() is not overwritten.\n",
"# Other clustering algorithms can be tried by swapping in one of these lines:\n",
"# cluster_model = KMeans(n_clusters=n_clusters, random_state=10)\n",
"# cluster_model = DBSCAN(eps=0.002, min_samples=2)\n",
"# cluster_model = Birch(n_clusters=n_clusters, threshold=0.2)\n",
"# cluster_model = SpectralClustering(n_clusters=n_clusters)\n",
"cluster_model = AgglomerativeClustering(n_clusters=n_clusters)\n",
"\n",
"y_predict = cluster_model.fit_predict(x)\n",
"\n",
"\n",
"def disp(wavname):\n",
"    \"\"\"Load `wavname` at 16 kHz and show an inline audio player for it.\"\"\"\n",
"    # Pass `sr` by keyword: it is keyword-only in librosa >= 0.10.\n",
"    wav, sr = librosa.load(wavname, sr=16000)\n",
"    ipd.display(ipd.Audio(wav, rate=sr))\n",
"\n",
"\n",
"# classes[k] collects the wav paths assigned to cluster k.\n",
"classes = [[] for i in range(y_predict.max() + 1)]\n",
"for idx, wavname in enumerate(wavnames):\n",
"    classes[y_predict[idx]].append(wavname)\n",
"\n",
"for i in range(y_predict.max() + 1):\n",
"    print(\"类别:\", i, \"本类中样本数量:\", len(classes[i]))\n",
"    # Preview at most 2 distinct clips per cluster.\n",
"    for wavname in random.sample(classes[i], min(2, len(classes[i]))):\n",
"        print(wavname)\n",
"        disp(wavname)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mo",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"vscode": {
"interpreter": {
"hash": "788ab866da3baa6c99886d56abb59fe71b6a552bf52c65473ecf96c784704db8"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}