mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
547 lines
19 KiB
Plaintext
Vendored
547 lines
19 KiB
Plaintext
Vendored
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from utils.hparams import load_hparams_json\n",
|
||
"from utils.util import intersperse\n",
|
||
"import json\n",
|
||
"from models.synthesizer.models.vits import Vits\n",
|
||
"import torch\n",
|
||
"import numpy as np\n",
|
||
"import IPython.display as ipd\n",
|
||
"from models.synthesizer.utils.symbols import symbols\n",
|
||
"from models.synthesizer.utils.text import text_to_sequence\n",
|
||
"\n",
|
||
"\n",
|
||
"hps = load_hparams_json(\"data/ckpt/synthesizer/vits5/config.json\")\n",
|
||
"print(hps.train)\n",
|
||
"model = Vits(\n",
|
||
" len(symbols),\n",
|
||
" hps[\"data\"][\"filter_length\"] // 2 + 1,\n",
|
||
" hps[\"train\"][\"segment_size\"] // hps[\"data\"][\"hop_length\"],\n",
|
||
" n_speakers=hps[\"data\"][\"n_speakers\"],\n",
|
||
" **hps[\"model\"])\n",
|
||
"_ = model.eval()\n",
|
||
"device = torch.device(\"cpu\")\n",
|
||
"checkpoint = torch.load(str(\"data/ckpt/synthesizer/vits5/G_56000.pth\"), map_location=device)\n",
|
||
"if \"model_state\" in checkpoint:\n",
|
||
" state = checkpoint[\"model_state\"]\n",
|
||
"else:\n",
|
||
" state = checkpoint[\"model\"]\n",
|
||
"model.load_state_dict(state, strict=False)\n",
|
||
"\n",
|
||
"# 随机抽取情感参考音频的根目录\n",
|
||
"random_emotion_root = \"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\emo\\\\\"\n",
|
||
"import random, re\n",
|
||
"from pypinyin import lazy_pinyin, Style\n",
|
||
"\n",
|
||
"import os\n",
|
||
"\n",
|
||
"def tts(txt, emotion, sid=0):\n",
|
||
" txt = \" \".join(lazy_pinyin(txt, style=Style.TONE3, neutral_tone_with_five=False))\n",
|
||
" text_norm = text_to_sequence(txt, hps[\"data\"][\"text_cleaners\"])\n",
|
||
" # if hps[\"data\"][\"add_blank\"]:\n",
|
||
" # text_norm = intersperse(text_norm, 0)\n",
|
||
" stn_tst = torch.LongTensor(text_norm)\n",
|
||
"\n",
|
||
" with torch.no_grad(): #inference mode\n",
|
||
" x_tst = stn_tst.unsqueeze(0)\n",
|
||
" x_tst_lengths = torch.LongTensor([stn_tst.size(0)])\n",
|
||
" sid = torch.LongTensor([sid])\n",
|
||
" if emotion.endswith(\"wav\"):\n",
|
||
" from models.synthesizer.preprocess_audio import extract_emo\n",
|
||
" import librosa\n",
|
||
" wav, sr = librosa.load(emotion, 16000)\n",
|
||
" emo = torch.FloatTensor(extract_emo(np.expand_dims(wav, 0), sr, embeddings=True))\n",
|
||
" elif emotion == \"random_sample\":\n",
|
||
" rand_emo = random.sample(os.listdir(random_emotion_root), 1)[0]\n",
|
||
" print(rand_emo)\n",
|
||
" emo = torch.FloatTensor(np.load(f\"{random_emotion_root}\\\\{rand_emo}\")).unsqueeze(0)\n",
|
||
" elif emotion.endswith(\"npy\"):\n",
|
||
" print(emotion)\n",
|
||
" emo = torch.FloatTensor(np.load(f\"{random_emotion_root}\\\\{emotion}\")).unsqueeze(0)\n",
|
||
" else:\n",
|
||
" print(\"emotion参数不正确\")\n",
|
||
"\n",
|
||
" audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.8, length_scale=1, emo=emo)[0][0,0].data.float().numpy()\n",
|
||
" ipd.display(ipd.Audio(audio, rate=hps[\"data\"][\"sampling_rate\"], normalize=False))\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"推理:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
|
||
"#正常: \n",
|
||
"tts(txt, emotion='emo-T0055G4906S0052.wav_00.npy', sid=100)\n",
|
||
"#快速:emo-T0055G2323S0179.wav_00.npy\n",
|
||
"\n",
|
||
"#难过:\n",
|
||
"tts(txt, emotion='emo-15_4581_20170825202626.wav_00.npy', sid=100)\n",
|
||
"\n",
|
||
"#开心:T0055G2412S0498.wav\n",
|
||
"tts(txt, emotion='emo-T0055G2412S0498.wav_00.npy', sid=100)\n",
|
||
"\n",
|
||
"#愤怒 T0055G1371S0363.wav T0055G1344S0160.wav\n",
|
||
"tts(txt, emotion='emo-T0055G1344S0160.wav_00.npy', sid=100)\n",
|
||
"\n",
|
||
"#疲惫\n",
|
||
"tts(txt, emotion='emo-T0055G2294S0476.wav_00.npy', sid=100)\n",
|
||
"\n",
|
||
"#着急\n",
|
||
"tts(txt, emotion='emo-T0055G1671S0170.wav_00.npy', sid=100)\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
|
||
"tts(txt, emotion='random_sample', sid=100)\n",
|
||
"tts(txt, emotion='random_sample', sid=100)\n",
|
||
"tts(txt, emotion='random_sample', sid=100)\n",
|
||
"tts(txt, emotion='random_sample', sid=100)\n",
|
||
"tts(txt, emotion='random_sample', sid=100)\n",
|
||
"tts(txt, emotion='random_sample', sid=100)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"txt = \"我们将其拓展到文本驱动数字人形象领域\"\n",
|
||
"types = [\"平淡\", \"激动\", \"疲惫\", \"兴奋\", \"沮丧\", \"开心\"]\n",
|
||
"for t in types:\n",
|
||
" print(t)\n",
|
||
" tts(txt, emotion=f'C:\\\\Users\\\\babys\\\\Music\\\\{t}.wav', sid=100)\n",
|
||
"# tts(txt, emotion='D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\G1858\\\\T0055G1858S0342.wav', sid=5)"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"预处理:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from models.synthesizer.preprocess import preprocess_dataset\n",
|
||
"from pathlib import Path\n",
|
||
"from utils.hparams import HParams\n",
|
||
"datasets_root = Path(\"../audiodata/\")\n",
|
||
"hparams = HParams(\n",
|
||
" n_fft = 1024, # filter_length\n",
|
||
" num_mels = 80,\n",
|
||
" hop_size = 256, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)\n",
|
||
" win_size = 1024, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)\n",
|
||
" fmin = 55,\n",
|
||
" min_level_db = -100,\n",
|
||
" ref_level_db = 20,\n",
|
||
" max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.\n",
|
||
" sample_rate = 16000,\n",
|
||
" rescale = True,\n",
|
||
" max_mel_frames = 900,\n",
|
||
" rescaling_max = 0.9, \n",
|
||
" preemphasis = 0.97, # Filter coefficient to use if preemphasize is True\n",
|
||
" preemphasize = True,\n",
|
||
" ### Mel Visualization and Griffin-Lim\n",
|
||
" signal_normalization = True,\n",
|
||
"\n",
|
||
" utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded\n",
|
||
" ### Audio processing options\n",
|
||
" fmax = 7600, # Should not exceed (sample_rate // 2)\n",
|
||
" allow_clipping_in_normalization = True, # Used when signal_normalization = True\n",
|
||
" clip_mels_length = True, # If true, discards samples exceeding max_mel_frames\n",
|
||
" use_lws = False, # \"Fast spectrogram phase recovery using local weighted sums\"\n",
|
||
" symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,\n",
|
||
" # and [0, max_abs_value] if False\n",
|
||
" trim_silence = False, # Use with sample_rate of 16000 for best results\n",
|
||
"\n",
|
||
")\n",
|
||
"preprocess_dataset(datasets_root=datasets_root, \n",
|
||
" out_dir=datasets_root.joinpath(\"SV2TTS\", \"synthesizer\"),\n",
|
||
" n_processes=8,\n",
|
||
" skip_existing=True, \n",
|
||
" hparams=hparams, \n",
|
||
" no_alignments=False, \n",
|
||
" dataset=\"aidatatang_200zh\", \n",
|
||
" emotion_extract=True)"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"训练:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from models.synthesizer.train_vits import run\n",
|
||
"from pathlib import Path\n",
|
||
"from utils.hparams import HParams\n",
|
||
"import torch, os\n",
|
||
"import torch.multiprocessing as mp\n",
|
||
"\n",
|
||
"datasets_root = Path(\"../audiodata/SV2TTS/synthesizer\")\n",
|
||
"hparams= HParams(\n",
|
||
" model_dir = \"data/ckpt/synthesizer/vits\",\n",
|
||
")\n",
|
||
"hparams.loadJson(Path(hparams.model_dir).joinpath(\"config.json\"))\n",
|
||
"hparams.data[\"training_files\"] = str(datasets_root.joinpath(\"train.txt\"))\n",
|
||
"hparams.data[\"validation_files\"] = str(datasets_root.joinpath(\"train.txt\"))\n",
|
||
"hparams.data[\"datasets_root\"] = str(datasets_root)\n",
|
||
"\n",
|
||
"n_gpus = torch.cuda.device_count()\n",
|
||
"# for spawn\n",
|
||
"os.environ['MASTER_ADDR'] = 'localhost'\n",
|
||
"os.environ['MASTER_PORT'] = '8899'\n",
|
||
"mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hparams))"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"挑选只有对应emo文件的meta数据"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from pathlib import Path\n",
|
||
"import os\n",
|
||
"root = Path('../audiodata/SV2TTS/synthesizer')\n",
|
||
"dict_info = []\n",
|
||
"with open(root.joinpath(\"train.txt\"), \"r\", encoding=\"utf-8\") as dict_meta:\n",
|
||
" for raw in dict_meta:\n",
|
||
" if not raw:\n",
|
||
" continue\n",
|
||
" v = raw.split(\"|\")[0].replace(\"audio\",\"emo\")\n",
|
||
" emo_fpath = root.joinpath(\"emo\").joinpath(v)\n",
|
||
" if emo_fpath.exists():\n",
|
||
" dict_info.append(raw)\n",
|
||
" # else:\n",
|
||
" # print(emo_fpath)\n",
|
||
"# Iterate over each wav\n",
|
||
"meta2 = Path('../audiodata/SV2TTS/synthesizer/train2.txt')\n",
|
||
"metadata_file = meta2.open(\"w\", encoding=\"utf-8\")\n",
|
||
"for new_info in dict_info:\n",
|
||
" metadata_file.write(new_info)\n",
|
||
"metadata_file.close()"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"从训练集中抽取10%作为测试集"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from pathlib import Path\n",
|
||
"root = Path('../audiodata/SV2TTS/synthesizer')\n",
|
||
"dict_info1 = []\n",
|
||
"dict_info2 = []\n",
|
||
"count = 1\n",
|
||
"with open(root.joinpath(\"train.txt\"), \"r\", encoding=\"utf-8\") as dict_meta:\n",
|
||
" for raw in dict_meta:\n",
|
||
" if not raw:\n",
|
||
" continue\n",
|
||
" if count % 10 == 0:\n",
|
||
" dict_info2.append(raw)\n",
|
||
" else:\n",
|
||
" dict_info1.append(raw)\n",
|
||
" count += 1\n",
|
||
"# Iterate over each wav\n",
|
||
"meta1 = Path('../audiodata/SV2TTS/synthesizer/train1.txt')\n",
|
||
"metadata_file = meta1.open(\"w\", encoding=\"utf-8\")\n",
|
||
"for new_info in dict_info1:\n",
|
||
" metadata_file.write(new_info)\n",
|
||
"metadata_file.close()\n",
|
||
"\n",
|
||
"meta2 = Path('../audiodata/SV2TTS/synthesizer/eval.txt')\n",
|
||
"metadata_file = meta2.open(\"w\", encoding=\"utf-8\")\n",
|
||
"for new_info in dict_info2:\n",
|
||
" metadata_file.write(new_info)\n",
|
||
"metadata_file.close()"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"evaluation"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from pathlib import Path\n",
|
||
"root = Path('../audiodata/SV2TTS/synthesizer')\n",
|
||
"spks = []\n",
|
||
"spk_id = {}\n",
|
||
"rows = []\n",
|
||
"with open(root.joinpath(\"eval.txt\"), \"r\", encoding=\"utf-8\") as dict_meta:\n",
|
||
" for raw in dict_meta:\n",
|
||
" speaker_name = raw.split(\"-\")[1][6:10]\n",
|
||
" if speaker_name not in spk_id:\n",
|
||
" spks.append(speaker_name)\n",
|
||
" spk_id[speaker_name] = 1\n",
|
||
" rows.append(raw)\n",
|
||
"i = 0\n",
|
||
"spks.sort()\n",
|
||
"\n",
|
||
"for sp in spks:\n",
|
||
" spk_id[sp] = str(i)\n",
|
||
" i = i + 1\n",
|
||
"print(len(spks))\n",
|
||
"meta2 = Path('../audiodata/SV2TTS/synthesizer/eval2.txt')\n",
|
||
"metadata_file = meta2.open(\"w\", encoding=\"utf-8\")\n",
|
||
"for row in rows:\n",
|
||
" speaker_n = row.split(\"-\")[1][6:10]\n",
|
||
" metadata_file.write(row.strip()+\"|\"+spk_id[speaker_n]+\"\\n\")\n",
|
||
"metadata_file.close()\n"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"[Not Recommended]\n",
|
||
"Try to transcript map to detailed format:\n",
|
||
"ni3 hao3 -> n i3 <pad> h ao3\n",
|
||
"\n",
|
||
"After couple of tests, I think this method will not improve the quality of result and may cause the crash of monotonic alignment."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"\n",
|
||
"from pathlib import Path\n",
|
||
"datasets_root = Path(\"../audiodata/SV2TTS/synthesizer/\")\n",
|
||
"\n",
|
||
"dictionary_fp = Path(\"../audiodata/ProDiff/processed/mandarin_pinyin.dict\")\n",
|
||
"dict_map = {}\n",
|
||
"for l in open(dictionary_fp, encoding='utf-8').readlines():\n",
|
||
" item = l.split(\"\\t\")\n",
|
||
" dict_map[item[0]] = item[1].replace(\"\\n\",\"\")\n",
|
||
"\n",
|
||
"with datasets_root.joinpath('train2.txt').open(\"w+\", encoding='utf-8') as f:\n",
|
||
" for l in open(datasets_root.joinpath('train.txt'), encoding='utf-8').readlines():\n",
|
||
" items = l.strip().replace(\"\\n\",\"\").replace(\"\\t\",\" \").split(\"|\")\n",
|
||
" phs_str = \"\"\n",
|
||
" for word in items[5].split(\" \"):\n",
|
||
" if word in dict_map:\n",
|
||
" phs_str += dict_map[word] \n",
|
||
" else:\n",
|
||
" phs_str += word\n",
|
||
" phs_str += \" _ \"\n",
|
||
" items[5] = phs_str\n",
|
||
" # if not os.path.exists(mfa_input_root.joinpath('train.txt')):\n",
|
||
" # with open(mfa_input_root.joinpath(fileName + 'lab'), 'w+', encoding=\"utf-8\") as f:\n",
|
||
" f.write(\"|\".join(items) + \"\\n\")"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"预处理后的数据可视化"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import librosa.display\n",
|
||
"import librosa, torch\n",
|
||
"import numpy as np\n",
|
||
"from utils.audio_utils import spectrogram, mel_spectrogram, load_wav_to_torch, spec_to_mel\n",
|
||
"\n",
|
||
"# x, sr = librosa.load(\"D:\\audiodata\\SV2TTS\\synthesizer\\audio\\audio-T0055G2333S0196.wav_00.npy\")\n",
|
||
"x = np.load(\"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\audio\\\\audio-T0055G1858S0342.wav_00.npy\")\n",
|
||
"\n",
|
||
"plt.figure(figsize=(14, 5))\n",
|
||
"librosa.display.waveplot(x)\n",
|
||
"\n",
|
||
"X = librosa.stft(x)\n",
|
||
"Xdb = librosa.amplitude_to_db(abs(X))\n",
|
||
"plt.figure(figsize=(14, 5))\n",
|
||
"librosa.display.specshow(Xdb, x_axis='time', y_axis='hz')\n",
|
||
"\n",
|
||
"# spectrogram = np.load(\"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\mels\\\\mel-T0055G1858S0342.wav_00.npy\")\n",
|
||
"audio = torch.from_numpy(x.astype(np.float32))\n",
|
||
"\n",
|
||
"# audio, sampling_rate = load_wav_to_torch(\"D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\G1858\\\\T0055G1858S0342.wav\")\n",
|
||
"# audio_norm = audio / 32768.0\n",
|
||
"audio_norm = audio.unsqueeze(0)\n",
|
||
"spec = spectrogram(audio_norm, 1024, 256, 1024, center=False)\n",
|
||
"# spec = spec_to_mel()\n",
|
||
"spec = torch.squeeze(spec, 0)\n",
|
||
"mel = spec_to_mel(spec, 1024, 80, 16000, 0, None)\n",
|
||
"\n",
|
||
"fig = plt.figure(figsize=(10, 8))\n",
|
||
"ax2 = fig.add_subplot(211)\n",
|
||
"im = ax2.imshow(mel, interpolation=\"none\")"
|
||
]
|
||
},
|
||
{
|
||
"attachments": {},
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"情感聚类"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"\n",
|
||
"# from sklearn import metrics\n",
|
||
"# from sklearn.mixture import GaussianMixture # 高斯混合模型\n",
|
||
"import os\n",
|
||
"import numpy as np\n",
|
||
"import librosa\n",
|
||
"import IPython.display as ipd\n",
|
||
"from random import sample\n",
|
||
"\n",
|
||
"embs = []\n",
|
||
"wavnames = []\n",
|
||
"emo_root_path = \"D:\\\\audiodata\\\\SV2TTS\\\\synthesizer\\\\emo\\\\\"\n",
|
||
"wav_root_path = \"D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\\"\n",
|
||
"for idx, emo_fpath in enumerate(sample(os.listdir(emo_root_path), 10000)):\n",
|
||
" if emo_fpath.endswith(\".npy\") and emo_fpath.startswith(\"emo-T\"):\n",
|
||
" embs.append(np.expand_dims(np.load(emo_root_path + emo_fpath), axis=0))\n",
|
||
" wav_fpath = wav_root_path + emo_fpath[9:14] + \"\\\\\" + emo_fpath.split(\"_00\")[0][4:]\n",
|
||
" wavnames.append(wav_fpath)\n",
|
||
"print(len(embs))\n",
|
||
"\n",
|
||
"\n",
|
||
"x = np.concatenate(embs, axis=0)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 聚类算法类的数量\n",
|
||
"n_clusters = 20\n",
|
||
"from sklearn.cluster import *\n",
|
||
"# model = KMeans(n_clusters=n_clusters, random_state=10)\n",
|
||
"# model = DBSCAN(eps=0.002, min_samples=2)\n",
|
||
"# 可以自行尝试各种不同的聚类算法\n",
|
||
"# model = Birch(n_clusters= n_clusters, threshold= 0.2)\n",
|
||
"# model = SpectralClustering(n_clusters=n_clusters)\n",
|
||
"model = AgglomerativeClustering(n_clusters= n_clusters)\n",
|
||
"import random\n",
|
||
"\n",
|
||
"y_predict = model.fit_predict(x)\n",
|
||
"\n",
|
||
"def disp(wavname):\n",
|
||
" wav, sr =librosa.load(wavname, 16000)\n",
|
||
" display(ipd.Audio(wav, rate=sr))\n",
|
||
"\n",
|
||
"classes=[[] for i in range(y_predict.max()+1)]\n",
|
||
"\n",
|
||
"for idx, wavname in enumerate(wavnames):\n",
|
||
" classes[y_predict[idx]].append(wavname)\n",
|
||
"\n",
|
||
"for i in range(y_predict.max()+1):\n",
|
||
" print(\"类别:\", i, \"本类中样本数量:\", len(classes[i]))\n",
|
||
" \"\"\"每一个类只预览2条音频\"\"\"\n",
|
||
" for j in range(2):\n",
|
||
" idx = random.randint(0, len(classes[i]) - 1)\n",
|
||
" print(classes[i][idx])\n",
|
||
" disp(classes[i][idx])"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "mo",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.7"
|
||
},
|
||
"vscode": {
|
||
"interpreter": {
|
||
"hash": "788ab866da3baa6c99886d56abb59fe71b6a552bf52c65473ecf96c784704db8"
|
||
}
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 4
|
||
}
|