MockingBird/vits.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from utils.hparams import load_hparams_json\n",
"from utils.util import intersperse\n",
"import json\n",
"from models.synthesizer.models.vits import Vits\n",
"import torch\n",
"import numpy as np\n",
"import IPython.display as ipd\n",
"from models.synthesizer.utils.symbols import symbols\n",
"from models.synthesizer.utils.text import text_to_sequence\n",
"\n",
"\n",
"hps = load_hparams_json(\"data/ckpt/synthesizer/vits5/config.json\")\n",
"print(hps.train)\n",
"model = Vits(\n",
" len(symbols),\n",
" hps[\"data\"][\"filter_length\"] // 2 + 1,\n",
" hps[\"train\"][\"segment_size\"] // hps[\"data\"][\"hop_length\"],\n",
" n_speakers=hps[\"data\"][\"n_speakers\"],\n",
" **hps[\"model\"])\n",
"_ = model.eval()\n",
"device = torch.device(\"cpu\")\n",
"checkpoint = torch.load(str(\"data/ckpt/synthesizer/vits5/G_56000.pth\"), map_location=device)\n",
"if \"model_state\" in checkpoint:\n",
" state = checkpoint[\"model_state\"]\n",
"else:\n",
" state = checkpoint[\"model\"]\n",
"model.load_state_dict(state, strict=False)\n",
"\n",
"# 随机抽取情感参考音频的根目录\n",
2023-02-18 09:31:05 +08:00
"random_emotion_root = \"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\emo\\\\\"\n",
"import random, re\n",
"from pypinyin import lazy_pinyin, Style\n",
"\n",
"import os\n",
"\n",
"def tts(txt, emotion, sid=0):\n",
" txt = \" \".join(lazy_pinyin(txt, style=Style.TONE3, neutral_tone_with_five=False))\n",
" text_norm = text_to_sequence(txt, hps[\"data\"][\"text_cleaners\"])\n",
" # if hps[\"data\"][\"add_blank\"]:\n",
" # text_norm = intersperse(text_norm, 0)\n",
" stn_tst = torch.LongTensor(text_norm)\n",
"\n",
" with torch.no_grad(): #inference mode\n",
" x_tst = stn_tst.unsqueeze(0)\n",
" x_tst_lengths = torch.LongTensor([stn_tst.size(0)])\n",
" sid = torch.LongTensor([sid])\n",
" if emotion.endswith(\"wav\"):\n",
" from models.synthesizer.preprocess_audio import extract_emo\n",
" import librosa\n",
" wav, sr = librosa.load(emotion, 16000)\n",
" emo = torch.FloatTensor(extract_emo(np.expand_dims(wav, 0), sr, embeddings=True))\n",
" elif emotion == \"random_sample\":\n",
" rand_emo = random.sample(os.listdir(random_emotion_root), 1)[0]\n",
" print(rand_emo)\n",
" emo = torch.FloatTensor(np.load(f\"{random_emotion_root}\\\\{rand_emo}\")).unsqueeze(0)\n",
" elif emotion.endswith(\"npy\"):\n",
" print(emotion)\n",
" emo = torch.FloatTensor(np.load(f\"{random_emotion_root}\\\\{emotion}\")).unsqueeze(0)\n",
"        else:\n",
"            print(\"Invalid emotion argument\")\n",
"            return\n",
"\n",
" audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.8, length_scale=1, emo=emo)[0][0,0].data.float().numpy()\n",
" ipd.display(ipd.Audio(audio, rate=hps[\"data\"][\"sampling_rate\"], normalize=False))\n",
"\n",
"\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"推理:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
"#正常: \n",
"tts(txt, emotion='emo-T0055G4906S0052.wav_00.npy', sid=100)\n",
"#快速emo-T0055G2323S0179.wav_00.npy\n",
"\n",
"#难过:\n",
"tts(txt, emotion='emo-15_4581_20170825202626.wav_00.npy', sid=100)\n",
"\n",
"#开心T0055G2412S0498.wav\n",
"tts(txt, emotion='emo-T0055G2412S0498.wav_00.npy', sid=100)\n",
"\n",
"#愤怒 T0055G1371S0363.wav T0055G1344S0160.wav\n",
"tts(txt, emotion='emo-T0055G1344S0160.wav_00.npy', sid=100)\n",
"\n",
"#疲惫\n",
"tts(txt, emotion='emo-T0055G2294S0476.wav_00.npy', sid=100)\n",
"\n",
"#着急\n",
"tts(txt, emotion='emo-T0055G1671S0170.wav_00.npy', sid=100)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
"tts(txt, emotion='random_sample', sid=100)\n",
"tts(txt, emotion='random_sample', sid=100)\n",
"tts(txt, emotion='random_sample', sid=100)\n",
"tts(txt, emotion='random_sample', sid=100)\n",
"tts(txt, emotion='random_sample', sid=100)\n",
"tts(txt, emotion='random_sample', sid=100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
"types = [\"平淡\", \"激动\", \"疲惫\", \"兴奋\", \"沮丧\", \"开心\"]\n",
"for t in types:\n",
" print(t)\n",
" tts(txt, emotion=f'C:\\\\Users\\\\babys\\\\Music\\\\{t}.wav', sid=100)\n",
"# tts(txt, emotion='D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\G1858\\\\T0055G1858S0342.wav', sid=5)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"预处理:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from models.synthesizer.preprocess import preprocess_dataset\n",
"from pathlib import Path\n",
"from utils.hparams import HParams\n",
"datasets_root = Path(\"../audiodata/\")\n",
"hparams = HParams(\n",
" n_fft = 1024, # filter_length\n",
" num_mels = 80,\n",
" hop_size = 256, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)\n",
" win_size = 1024, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)\n",
" fmin = 55,\n",
" min_level_db = -100,\n",
" ref_level_db = 20,\n",
" max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.\n",
" sample_rate = 16000,\n",
" rescale = True,\n",
" max_mel_frames = 900,\n",
" rescaling_max = 0.9, \n",
" preemphasis = 0.97, # Filter coefficient to use if preemphasize is True\n",
" preemphasize = True,\n",
" ### Mel Visualization and Griffin-Lim\n",
" signal_normalization = True,\n",
"\n",
" utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded\n",
" ### Audio processing options\n",
" fmax = 7600, # Should not exceed (sample_rate // 2)\n",
" allow_clipping_in_normalization = True, # Used when signal_normalization = True\n",
" clip_mels_length = True, # If true, discards samples exceeding max_mel_frames\n",
" use_lws = False, # \"Fast spectrogram phase recovery using local weighted sums\"\n",
" symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,\n",
" # and [0, max_abs_value] if False\n",
" trim_silence = False, # Use with sample_rate of 16000 for best results\n",
"\n",
")\n",
"preprocess_dataset(datasets_root=datasets_root, \n",
" out_dir=datasets_root.joinpath(\"SV2TTS\", \"synthesizer\"),\n",
" n_processes=8,\n",
" skip_existing=True, \n",
" hparams=hparams, \n",
" no_alignments=False, \n",
" dataset=\"aidatatang_200zh\", \n",
" emotion_extract=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"训练:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from models.synthesizer.train_vits import run\n",
"from pathlib import Path\n",
"from utils.hparams import HParams\n",
"import torch, os\n",
"import torch.multiprocessing as mp\n",
"\n",
"datasets_root = Path(\"../audiodata/SV2TTS/synthesizer\")\n",
"hparams= HParams(\n",
" model_dir = \"data/ckpt/synthesizer/vits\",\n",
")\n",
"hparams.loadJson(Path(hparams.model_dir).joinpath(\"config.json\"))\n",
"hparams.data[\"training_files\"] = str(datasets_root.joinpath(\"train.txt\"))\n",
"hparams.data[\"validation_files\"] = str(datasets_root.joinpath(\"train.txt\"))\n",
"hparams.data[\"datasets_root\"] = str(datasets_root)\n",
"\n",
"n_gpus = torch.cuda.device_count()\n",
"# for spawn\n",
"os.environ['MASTER_ADDR'] = 'localhost'\n",
"os.environ['MASTER_PORT'] = '8899'\n",
"mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hparams))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"挑选只有对应emo文件的meta数据"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import os\n",
"root = Path('../audiodata/SV2TTS/synthesizer')\n",
"dict_info = []\n",
"with open(root.joinpath(\"train.txt\"), \"r\", encoding=\"utf-8\") as dict_meta:\n",
" for raw in dict_meta:\n",
" if not raw:\n",
" continue\n",
" v = raw.split(\"|\")[0].replace(\"audio\",\"emo\")\n",
" emo_fpath = root.joinpath(\"emo\").joinpath(v)\n",
" if emo_fpath.exists():\n",
" dict_info.append(raw)\n",
" # else:\n",
" # print(emo_fpath)\n",
"# Iterate over each wav\n",
"meta2 = Path('../audiodata/SV2TTS/synthesizer/train2.txt')\n",
"metadata_file = meta2.open(\"w\", encoding=\"utf-8\")\n",
"for new_info in dict_info:\n",
" metadata_file.write(new_info)\n",
"metadata_file.close()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"从训练集中抽取10%作为测试集"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"root = Path('../audiodata/SV2TTS/synthesizer')\n",
"dict_info1 = []\n",
"dict_info2 = []\n",
"count = 1\n",
"with open(root.joinpath(\"train.txt\"), \"r\", encoding=\"utf-8\") as dict_meta:\n",
" for raw in dict_meta:\n",
" if not raw:\n",
" continue\n",
" if count % 10 == 0:\n",
" dict_info2.append(raw)\n",
" else:\n",
" dict_info1.append(raw)\n",
" count += 1\n",
"# Write the two splits to train1.txt and eval.txt\n",
"meta1 = Path('../audiodata/SV2TTS/synthesizer/train1.txt')\n",
"metadata_file = meta1.open(\"w\", encoding=\"utf-8\")\n",
"for new_info in dict_info1:\n",
" metadata_file.write(new_info)\n",
"metadata_file.close()\n",
"\n",
"meta2 = Path('../audiodata/SV2TTS/synthesizer/eval.txt')\n",
"metadata_file = meta2.open(\"w\", encoding=\"utf-8\")\n",
"for new_info in dict_info2:\n",
" metadata_file.write(new_info)\n",
"metadata_file.close()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"root = Path('../audiodata/SV2TTS/synthesizer')\n",
"spks = []\n",
"spk_id = {}\n",
"rows = []\n",
"with open(root.joinpath(\"eval.txt\"), \"r\", encoding=\"utf-8\") as dict_meta:\n",
" for raw in dict_meta:\n",
" speaker_name = raw.split(\"-\")[1][6:10]\n",
" if speaker_name not in spk_id:\n",
" spks.append(speaker_name)\n",
" spk_id[speaker_name] = 1\n",
" rows.append(raw)\n",
"i = 0\n",
"spks.sort()\n",
"\n",
"for sp in spks:\n",
" spk_id[sp] = str(i)\n",
" i = i + 1\n",
"print(len(spks))\n",
"meta2 = Path('../audiodata/SV2TTS/synthesizer/eval2.txt')\n",
"metadata_file = meta2.open(\"w\", encoding=\"utf-8\")\n",
"for row in rows:\n",
" speaker_n = row.split(\"-\")[1][6:10]\n",
" metadata_file.write(row.strip()+\"|\"+spk_id[speaker_n]+\"\\n\")\n",
"metadata_file.close()\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"[Not Recommended]\n",
"Try to transcript map to detailed format:\n",
"ni3 hao3 -> n i3 <pad> h ao3\n",
"\n",
"After couple of tests, I think this method will not improve the quality of result and may cause the crash of monotonic alignment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from pathlib import Path\n",
"datasets_root = Path(\"../audiodata/SV2TTS/synthesizer/\")\n",
"\n",
"dictionary_fp = Path(\"../audiodata/ProDiff/processed/mandarin_pinyin.dict\")\n",
"dict_map = {}\n",
"for l in open(dictionary_fp, encoding='utf-8').readlines():\n",
" item = l.split(\"\\t\")\n",
" dict_map[item[0]] = item[1].replace(\"\\n\",\"\")\n",
"\n",
"with datasets_root.joinpath('train2.txt').open(\"w+\", encoding='utf-8') as f:\n",
" for l in open(datasets_root.joinpath('train.txt'), encoding='utf-8').readlines():\n",
" items = l.strip().replace(\"\\n\",\"\").replace(\"\\t\",\" \").split(\"|\")\n",
" phs_str = \"\"\n",
" for word in items[5].split(\" \"):\n",
" if word in dict_map:\n",
" phs_str += dict_map[word] \n",
" else:\n",
" phs_str += word\n",
" phs_str += \" _ \"\n",
" items[5] = phs_str\n",
" # if not os.path.exists(mfa_input_root.joinpath('train.txt')):\n",
" # with open(mfa_input_root.joinpath(fileName + 'lab'), 'w+', encoding=\"utf-8\") as f:\n",
" f.write(\"|\".join(items) + \"\\n\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"预处理后的数据可视化"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import librosa.display\n",
"import librosa, torch\n",
"import numpy as np\n",
"from utils.audio_utils import spectrogram, mel_spectrogram, load_wav_to_torch, spec_to_mel\n",
"\n",
"# x, sr = librosa.load(\"D:\\audiodata\\SV2TTS\\synthesizer\\audio\\audio-T0055G2333S0196.wav_00.npy\")\n",
"x = np.load(\"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\audio\\\\audio-T0055G1858S0342.wav_00.npy\")\n",
"\n",
"plt.figure(figsize=(14, 5))\n",
"librosa.display.waveplot(x)\n",
"\n",
"X = librosa.stft(x)\n",
"Xdb = librosa.amplitude_to_db(abs(X))\n",
"plt.figure(figsize=(14, 5))\n",
"librosa.display.specshow(Xdb, x_axis='time', y_axis='hz')\n",
"\n",
"# spectrogram = np.load(\"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\mels\\\\mel-T0055G1858S0342.wav_00.npy\")\n",
"audio = torch.from_numpy(x.astype(np.float32))\n",
"\n",
"# audio, sampling_rate = load_wav_to_torch(\"D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\G1858\\\\T0055G1858S0342.wav\")\n",
"# audio_norm = audio / 32768.0\n",
"audio_norm = audio.unsqueeze(0)\n",
"spec = spectrogram(audio_norm, 1024, 256, 1024, center=False)\n",
"# spec = spec_to_mel()\n",
"spec = torch.squeeze(spec, 0)\n",
"mel = spec_to_mel(spec, 1024, 80, 16000, 0, None)\n",
"\n",
"fig = plt.figure(figsize=(10, 8))\n",
"ax2 = fig.add_subplot(211)\n",
"im = ax2.imshow(mel, interpolation=\"none\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"情感聚类"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# from sklearn import metrics\n",
"# from sklearn.mixture import GaussianMixture # 高斯混合模型\n",
"import os\n",
"import numpy as np\n",
"import librosa\n",
"import IPython.display as ipd\n",
"from random import sample\n",
"\n",
"embs = []\n",
"wavnames = []\n",
"emo_root_path = \"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\emo\\\\\"\n",
"wav_root_path = \"D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\\"\n",
"for idx, emo_fpath in enumerate(sample(os.listdir(emo_root_path), 10000)):\n",
" if emo_fpath.endswith(\".npy\") and emo_fpath.startswith(\"emo-T\"):\n",
" embs.append(np.expand_dims(np.load(emo_root_path + emo_fpath), axis=0))\n",
" wav_fpath = wav_root_path + emo_fpath[9:14] + \"\\\\\" + emo_fpath.split(\"_00\")[0][4:]\n",
" wavnames.append(wav_fpath)\n",
"print(len(embs))\n",
"\n",
"\n",
"x = np.concatenate(embs, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 聚类算法类的数量\n",
"n_clusters = 20\n",
"from sklearn.cluster import *\n",
"# model = KMeans(n_clusters=n_clusters, random_state=10)\n",
"# model = DBSCAN(eps=0.002, min_samples=2)\n",
"# 可以自行尝试各种不同的聚类算法\n",
"# model = Birch(n_clusters= n_clusters, threshold= 0.2)\n",
"# model = SpectralClustering(n_clusters=n_clusters)\n",
"model = AgglomerativeClustering(n_clusters= n_clusters)\n",
"import random\n",
"\n",
"y_predict = model.fit_predict(x)\n",
"\n",
"def disp(wavname):\n",
" wav, sr =librosa.load(wavname, 16000)\n",
" display(ipd.Audio(wav, rate=sr))\n",
"\n",
"classes=[[] for i in range(y_predict.max()+1)]\n",
"\n",
"for idx, wavname in enumerate(wavnames):\n",
" classes[y_predict[idx]].append(wavname)\n",
"\n",
"for i in range(y_predict.max()+1):\n",
" print(\"类别:\", i, \"本类中样本数量:\", len(classes[i]))\n",
" \"\"\"每一个类只预览2条音频\"\"\"\n",
" for j in range(2):\n",
" idx = random.randint(0, len(classes[i]) - 1)\n",
" print(classes[i][idx])\n",
" disp(classes[i][idx])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mo",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"vscode": {
"interpreter": {
"hash": "788ab866da3baa6c99886d56abb59fe71b6a552bf52c65473ecf96c784704db8"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}