Fix known issues

babysor00 2022-02-20 11:56:58 +08:00
parent 19eaa68202
commit fad5023fca
4 changed files with 37 additions and 23 deletions

View File

@@ -4,6 +4,8 @@ import torch
 import numpy as np
 from tqdm import tqdm
 from pathlib import Path
+import soundfile
+import resampy
 from ppg_extractor import load_model
 import encoder.inference as Encoder
@@ -32,11 +34,13 @@ def _compute_bnf(
     bnf = ppg_model_local(wav_tensor, wav_length)
     bnf_npy = bnf.squeeze(0).cpu().numpy()
     np.save(output_fpath, bnf_npy, allow_pickle=False)
+    return bnf_npy, len(bnf_npy)

 def _compute_f0_from_wav(wav, output_fpath):
     """Compute merged f0 values."""
     f0 = compute_f0(wav, SAMPLE_RATE)
     np.save(output_fpath, f0, allow_pickle=False)
+    return f0, len(f0)

 def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
     Encoder.set_model(encoder_model_local)
@@ -56,14 +60,22 @@ def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
     embed = raw_embed / np.linalg.norm(raw_embed, 2)
     np.save(output_fpath, embed, allow_pickle=False)
+    return embed, len(embed)

 def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
-    wav = preprocess_wav(wav_path, 24000, False, False)
+    # wav = preprocess_wav(wav_path)
+    # try:
+    wav, sr = soundfile.read(wav_path)
+    if len(wav) < sr:
+        return None, sr, len(wav)
+    if sr != SAMPLE_RATE:
+        wav = resampy.resample(wav, sr, SAMPLE_RATE)
+        sr = SAMPLE_RATE
     utt_id = os.path.basename(wav_path).rstrip(".wav")
-    _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
-    _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
-    _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)
+    _, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
+    _, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
+    _, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)

 def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
     # Glob wav files
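Note: the new read path replaces preprocess_wav with a direct soundfile read plus an explicit resample guard. A minimal standalone sketch of what it does (the helper name is illustrative, not from the repo); soundfile.read returns (data, samplerate), and resampy.resample(x, sr_orig, sr_new) resamples along the last axis:

```python
import soundfile
import resampy

SAMPLE_RATE = 16000

def load_wav_16k(wav_path):
    wav, sr = soundfile.read(wav_path)
    if len(wav) < sr:      # skip clips shorter than one second
        return None
    if sr != SAMPLE_RATE:  # resample anything that is not already 16 kHz
        wav = resampy.resample(wav, sr, SAMPLE_RATE)
    return wav
```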
@@ -73,8 +85,7 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder
     out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
     out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
     out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
-    # ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
-    ppg_model_local = None
+    ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
     encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
     if n_processes is None:
         n_processes = cpu_count()
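Loading both models on "cpu" here presumably keeps them picklable so they can be shipped into the Pool workers used below. The exact partial(...) construction sits between these two hunks and is not shown; given preprocess_one's signature it would look roughly like this (a hypothetical reconstruction, not the verbatim source):

```python
from functools import partial

# Hypothetical call site that feeds Pool.imap in the next hunk.
func = partial(preprocess_one, out_dir=out_dir, device=device,
               ppg_model_local=ppg_model_local,
               encoder_model_local=encoder_model_local)
```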
@@ -84,14 +95,15 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder
     job = Pool(n_processes).imap(func, wav_file_list)
     list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))

+    # finish processing and mark
     t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
     d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
     e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
-    for file in wav_file_list:
-        id = os.path.basename(file).rstrip(".wav")
-        if id.endswith("1"):
+    for file in sorted(out_dir.joinpath("f0").glob("*.npy")):
+        id = os.path.basename(file).rstrip(".f0.npy")
+        if id.endswith("01"):
             d_fid_file.write(id + "\n")
-        elif id.endswith("9"):
+        elif id.endswith("09"):
             e_fid_file.write(id + "\n")
         else:
             t_fid_file.write(id + "\n")
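One caveat on the id handling in this hunk (and the .rstrip(".wav") in preprocess_one above): str.rstrip strips a trailing *character set*, not a suffix, so "0010.f0.npy".rstrip(".f0.npy") yields "001" rather than "0010", and the mangled id can then be mis-routed by the endswith checks. A suffix-safe alternative (sketch; strip_suffix is a hypothetical helper, and Python 3.9+ has str.removesuffix built in):

```python
import os

def strip_suffix(name, suffix):
    # Remove an exact trailing suffix, unlike str.rstrip's character-set semantics.
    return name[:-len(suffix)] if name.endswith(suffix) else name

utt_id = strip_suffix(os.path.basename("0010.f0.npy"), ".f0.npy")  # -> "0010"
```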

View File

@@ -23,7 +23,7 @@ if __name__ == "__main__":
     parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
         "Path to the output directory that will contain the mel spectrograms, the audios and the "
         "embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/")
-    parser.add_argument("-n", "--n_processes", type=int, default=8, help=\
+    parser.add_argument("-n", "--n_processes", type=int, default=16, help=\
         "Number of processes in parallel.")
     # parser.add_argument("-s", "--skip_existing", action="store_true", help=\
     #     "Whether to overwrite existing files with the same name. Useful if the preprocessing was "

run.py
View File

@@ -39,7 +39,7 @@ def convert(args):
     # Build models
     print("Load PPG-model, PPG2Mel-model, Vocoder-model...")
     ppg_model = load_model(
-        './ppg_extractor/saved_models/24epoch.pt',
+        Path('./ppg_extractor/saved_models/24epoch.pt'),
         device,
     )
     ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device)
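The call site now hands load_model a Path instead of a bare string, which suggests the loader expects path-like behaviour. If both str and Path callers should keep working, one option is to coerce at the boundary (a sketch; load_model_compat is a hypothetical wrapper, not part of the repo):

```python
from pathlib import Path
from ppg_extractor import load_model

def load_model_compat(ckpt_path, device):
    # Accept str or Path from any call site by coercing once here.
    return load_model(Path(ckpt_path), device)
```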

View File

@@ -8,6 +8,8 @@ from librosa.util import normalize
 import os

+SAMPLE_RATE=16000
+
 def read_fids(fid_list_f):
     with open(fid_list_f, 'r') as f:
         fids = [l.strip().split()[0] for l in f if l.strip()]
@@ -71,8 +73,8 @@ class OneshotVcDataset(torch.utils.data.Dataset):

     def compute_mel(self, wav_path):
         audio, sr = load_wav(wav_path)
-        if sr != 24000:
-            audio = resampy.resample(audio, sr, 24000)
+        if sr != SAMPLE_RATE:
+            audio = resampy.resample(audio, sr, SAMPLE_RATE)
         audio = audio / MAX_WAV_VALUE
         audio = normalize(audio) * 0.95
         audio = torch.FloatTensor(audio).unsqueeze(0)
@@ -80,9 +82,9 @@ class OneshotVcDataset(torch.utils.data.Dataset):
             audio,
             n_fft=1024,
             num_mels=80,
-            sampling_rate=24000,
-            hop_size=240,
-            win_size=1024,
+            sampling_rate=SAMPLE_RATE,
+            hop_size=200,
+            win_size=800,
             fmin=0,
             fmax=8000,
         )
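For reference, the old settings (24 kHz, hop 240, win 1024) correspond to a 10 ms hop and a roughly 42.7 ms window; the new ones keep similar geometry at 16 kHz. Quick arithmetic:

```python
SAMPLE_RATE = 16000
hop_ms = 200 / SAMPLE_RATE * 1000  # 12.5 ms hop -> 80 mel frames per second
win_ms = 800 / SAMPLE_RATE * 1000  # 50.0 ms window, zero-padded to n_fft=1024
print(hop_ms, win_ms)              # 12.5 50.0
```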
@@ -112,7 +114,7 @@ class OneshotVcDataset(torch.utils.data.Dataset):
         if self.min_max_norm_mel:
             mel = self.bin_level_min_max_norm(mel)

-        f0, ppg, mel = self._adjust_lengths(f0, ppg, mel)
+        f0, ppg, mel = self._adjust_lengths(f0, ppg, mel, fid)
         spk_dvec = self.get_spk_dvec(fid)

         # 2. Convert f0 to continuous log-f0 and u/v flags
@@ -132,15 +134,15 @@ class OneshotVcDataset(torch.utils.data.Dataset):
         return (ppg, lf0_uv, mel, spk_dvec, fid)

-    def check_lengths(self, f0, ppg, mel):
+    def check_lengths(self, f0, ppg, mel, fid):
         LEN_THRESH = 10
         assert abs(len(ppg) - len(f0)) <= LEN_THRESH, \
-            f"{abs(len(ppg) - len(f0))}"
+            f"{abs(len(ppg) - len(f0))}: for file {fid}"
         assert abs(len(mel) - len(f0)) <= LEN_THRESH, \
-            f"{abs(len(mel) - len(f0))}"
+            f"{abs(len(mel) - len(f0))}: for file {fid}"

-    def _adjust_lengths(self, f0, ppg, mel):
-        self.check_lengths(f0, ppg, mel)
+    def _adjust_lengths(self, f0, ppg, mel, fid):
+        self.check_lengths(f0, ppg, mel, fid)
         min_len = min(
             len(f0),
             len(ppg),
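The diff is truncated inside _adjust_lengths; presumably it goes on to trim all three sequences to the shortest length, with the new fid parameter existing only to make the assertion messages traceable to a file. A sketch of the likely shape, not the verbatim source:

```python
def _adjust_lengths(self, f0, ppg, mel, fid):
    # Validate first (asserts now report the offending fid), then truncate
    # f0, ppg and mel to a common length so downstream batching lines up.
    self.check_lengths(f0, ppg, mel, fid)
    min_len = min(len(f0), len(ppg), len(mel))
    return f0[:min_len], ppg[:min_len], mel[:min_len]
```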