Fix known issues

This commit is contained in:
babysor00 2022-02-20 11:56:58 +08:00
parent 19eaa68202
commit fad5023fca
4 changed files with 37 additions and 23 deletions

View File

@ -4,6 +4,8 @@ import torch
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from pathlib import Path from pathlib import Path
import soundfile
import resampy
from ppg_extractor import load_model from ppg_extractor import load_model
import encoder.inference as Encoder import encoder.inference as Encoder
@ -32,11 +34,13 @@ def _compute_bnf(
bnf = ppg_model_local(wav_tensor, wav_length) bnf = ppg_model_local(wav_tensor, wav_length)
bnf_npy = bnf.squeeze(0).cpu().numpy() bnf_npy = bnf.squeeze(0).cpu().numpy()
np.save(output_fpath, bnf_npy, allow_pickle=False) np.save(output_fpath, bnf_npy, allow_pickle=False)
return bnf_npy, len(bnf_npy)
def _compute_f0_from_wav(wav, output_fpath): def _compute_f0_from_wav(wav, output_fpath):
"""Compute merged f0 values.""" """Compute merged f0 values."""
f0 = compute_f0(wav, SAMPLE_RATE) f0 = compute_f0(wav, SAMPLE_RATE)
np.save(output_fpath, f0, allow_pickle=False) np.save(output_fpath, f0, allow_pickle=False)
return f0, len(f0)
def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device): def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
Encoder.set_model(encoder_model_local) Encoder.set_model(encoder_model_local)
@ -56,14 +60,22 @@ def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
embed = raw_embed / np.linalg.norm(raw_embed, 2) embed = raw_embed / np.linalg.norm(raw_embed, 2)
np.save(output_fpath, embed, allow_pickle=False) np.save(output_fpath, embed, allow_pickle=False)
return embed, len(embed)
def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local): def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
wav = preprocess_wav(wav_path, 24000, False, False) # wav = preprocess_wav(wav_path)
# try:
wav, sr = soundfile.read(wav_path)
if len(wav) < sr:
return None, sr, len(wav)
if sr != SAMPLE_RATE:
wav = resampy.resample(wav, sr, SAMPLE_RATE)
sr = SAMPLE_RATE
utt_id = os.path.basename(wav_path).rstrip(".wav") utt_id = os.path.basename(wav_path).rstrip(".wav")
_compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local) _, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
_compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav) _, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
_compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav) _, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)
def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model): def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
# Glob wav files # Glob wav files
@ -73,8 +85,7 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder
out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True) out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True) out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True) out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
# ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu") ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
ppg_model_local = None
encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu") encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
if n_processes is None: if n_processes is None:
n_processes = cpu_count() n_processes = cpu_count()
@ -84,14 +95,15 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder
job = Pool(n_processes).imap(func, wav_file_list) job = Pool(n_processes).imap(func, wav_file_list)
list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav")) list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))
# finish processing and mark
t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8") t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8") d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8") e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
for file in wav_file_list: for file in sorted(out_dir.joinpath("f0").glob("*.npy")):
id = os.path.basename(file).rstrip(".wav") id = os.path.basename(file).rstrip(".f0.npy")
if id.endswith("1"): if id.endswith("01"):
d_fid_file.write(id + "\n") d_fid_file.write(id + "\n")
elif id.endswith("9"): elif id.endswith("09"):
e_fid_file.write(id + "\n") e_fid_file.write(id + "\n")
else: else:
t_fid_file.write(id + "\n") t_fid_file.write(id + "\n")

View File

@ -23,7 +23,7 @@ if __name__ == "__main__":
parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\ parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
"Path to the output directory that will contain the mel spectrograms, the audios and the " "Path to the output directory that will contain the mel spectrograms, the audios and the "
"embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/") "embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/")
parser.add_argument("-n", "--n_processes", type=int, default=8, help=\ parser.add_argument("-n", "--n_processes", type=int, default=16, help=\
"Number of processes in parallel.") "Number of processes in parallel.")
# parser.add_argument("-s", "--skip_existing", action="store_true", help=\ # parser.add_argument("-s", "--skip_existing", action="store_true", help=\
# "Whether to overwrite existing files with the same name. Useful if the preprocessing was " # "Whether to overwrite existing files with the same name. Useful if the preprocessing was "

2
run.py
View File

@ -39,7 +39,7 @@ def convert(args):
# Build models # Build models
print("Load PPG-model, PPG2Mel-model, Vocoder-model...") print("Load PPG-model, PPG2Mel-model, Vocoder-model...")
ppg_model = load_model( ppg_model = load_model(
'./ppg_extractor/saved_models/24epoch.pt', Path('./ppg_extractor/saved_models/24epoch.pt'),
device, device,
) )
ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device) ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device)

View File

@ -8,6 +8,8 @@ from librosa.util import normalize
import os import os
SAMPLE_RATE=16000
def read_fids(fid_list_f): def read_fids(fid_list_f):
with open(fid_list_f, 'r') as f: with open(fid_list_f, 'r') as f:
fids = [l.strip().split()[0] for l in f if l.strip()] fids = [l.strip().split()[0] for l in f if l.strip()]
@ -71,8 +73,8 @@ class OneshotVcDataset(torch.utils.data.Dataset):
def compute_mel(self, wav_path): def compute_mel(self, wav_path):
audio, sr = load_wav(wav_path) audio, sr = load_wav(wav_path)
if sr != 24000: if sr != SAMPLE_RATE:
audio = resampy.resample(audio, sr, 24000) audio = resampy.resample(audio, sr, SAMPLE_RATE)
audio = audio / MAX_WAV_VALUE audio = audio / MAX_WAV_VALUE
audio = normalize(audio) * 0.95 audio = normalize(audio) * 0.95
audio = torch.FloatTensor(audio).unsqueeze(0) audio = torch.FloatTensor(audio).unsqueeze(0)
@ -80,9 +82,9 @@ class OneshotVcDataset(torch.utils.data.Dataset):
audio, audio,
n_fft=1024, n_fft=1024,
num_mels=80, num_mels=80,
sampling_rate=24000, sampling_rate=SAMPLE_RATE,
hop_size=240, hop_size=200,
win_size=1024, win_size=800,
fmin=0, fmin=0,
fmax=8000, fmax=8000,
) )
@ -112,7 +114,7 @@ class OneshotVcDataset(torch.utils.data.Dataset):
if self.min_max_norm_mel: if self.min_max_norm_mel:
mel = self.bin_level_min_max_norm(mel) mel = self.bin_level_min_max_norm(mel)
f0, ppg, mel = self._adjust_lengths(f0, ppg, mel) f0, ppg, mel = self._adjust_lengths(f0, ppg, mel, fid)
spk_dvec = self.get_spk_dvec(fid) spk_dvec = self.get_spk_dvec(fid)
# 2. Convert f0 to continuous log-f0 and u/v flags # 2. Convert f0 to continuous log-f0 and u/v flags
@ -132,15 +134,15 @@ class OneshotVcDataset(torch.utils.data.Dataset):
return (ppg, lf0_uv, mel, spk_dvec, fid) return (ppg, lf0_uv, mel, spk_dvec, fid)
def check_lengths(self, f0, ppg, mel): def check_lengths(self, f0, ppg, mel, fid):
LEN_THRESH = 10 LEN_THRESH = 10
assert abs(len(ppg) - len(f0)) <= LEN_THRESH, \ assert abs(len(ppg) - len(f0)) <= LEN_THRESH, \
f"{abs(len(ppg) - len(f0))}" f"{abs(len(ppg) - len(f0))}: for file {fid}"
assert abs(len(mel) - len(f0)) <= LEN_THRESH, \ assert abs(len(mel) - len(f0)) <= LEN_THRESH, \
f"{abs(len(mel) - len(f0))}" f"{abs(len(mel) - len(f0))}: for file {fid}"
def _adjust_lengths(self, f0, ppg, mel): def _adjust_lengths(self, f0, ppg, mel, fid):
self.check_lengths(f0, ppg, mel) self.check_lengths(f0, ppg, mel, fid)
min_len = min( min_len = min(
len(f0), len(f0),
len(ppg), len(ppg),