Fix sample issues

babysor00 2022-03-02 23:15:37 +08:00
parent dd3abebc4d
commit 6befb700e9
6 changed files with 17 additions and 156 deletions

(deleted file; name not shown)

@@ -1,142 +0,0 @@
-import time
-import os
-import argparse
-import torch
-import numpy as np
-import glob
-from pathlib import Path
-from tqdm import tqdm
-from ppg_extractor import load_model
-import librosa
-import soundfile as sf
-from utils.load_yaml import HpsYaml
-from encoder.audio import preprocess_wav
-from encoder import inference as speacker_encoder
-from vocoder.hifigan import inference as vocoder
-from ppg2mel import MelDecoderMOLv2
-from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
-
-
-def _build_ppg2mel_model(model_config, model_file, device):
-    ppg2mel_model = MelDecoderMOLv2(
-        **model_config["model"]
-    ).to(device)
-    ckpt = torch.load(model_file, map_location=device)
-    ppg2mel_model.load_state_dict(ckpt["model"])
-    ppg2mel_model.eval()
-    return ppg2mel_model
-
-
-@torch.no_grad()
-def convert(args):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    output_dir = args.output_dir
-    os.makedirs(output_dir, exist_ok=True)
-
-    step = os.path.basename(args.ppg2mel_model_file)[:-4].split("_")[-1]
-
-    # Build models
-    print("Load PPG-model, PPG2Mel-model, Vocoder-model...")
-    ppg_model = load_model(
-        './ppg_extractor/saved_models/24epoch.pt',
-        device,
-    )
-    ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device)
-    vocoder.load_model('./vocoder/saved_models/pretrained/g_hifigan.pt', "./vocoder/hifigan/config_16k_.json")
-    # vocoder.load_model('./vocoder/saved_models/pretrained/g_02830000.pt')
-
-    # Data related
-    ref_wav_path = args.ref_wav_path
-    ref_wav = preprocess_wav(ref_wav_path)
-    ref_fid = os.path.basename(ref_wav_path)[:-4]
-
-    # TODO: specify encoder
-    speacker_encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
-    ref_spk_dvec = speacker_encoder.embed_utterance(ref_wav)
-    ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device)
-    ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
-
-    source_file_list = sorted(glob.glob(f"{args.wav_dir}/*.wav"))
-    print(f"Number of source utterances: {len(source_file_list)}.")
-
-    total_rtf = 0.0
-    cnt = 0
-    for src_wav_path in tqdm(source_file_list):
-        # Load the audio to a numpy array:
-        src_wav, _ = librosa.load(src_wav_path, sr=16000)
-        src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device)
-        src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device)
-        ppg = ppg_model(src_wav_tensor, src_wav_lengths)
-
-        lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
-        min_len = min(ppg.shape[1], len(lf0_uv))
-        ppg = ppg[:, :min_len]
-        lf0_uv = lf0_uv[:min_len]
-
-        start = time.time()
-        _, mel_pred, att_ws = ppg2mel_model.inference(
-            ppg,
-            logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
-            spembs=ref_spk_dvec,
-        )
-        src_fid = os.path.basename(src_wav_path)[:-4]
-        wav_fname = f"{output_dir}/vc_{src_fid}_ref_{ref_fid}_step{step}.wav"
-        mel_len = mel_pred.shape[0]
-        rtf = (time.time() - start) / (0.01 * mel_len)
-        total_rtf += rtf
-        cnt += 1
-        # continue
-        mel_pred= mel_pred.transpose(0, 1)
-        y = vocoder.infer_waveform(mel_pred.cpu())
-        sf.write(wav_fname, y.squeeze(), 16000, "PCM_16")
-
-    print("RTF:")
-    print(total_rtf / cnt)
-
-
-def get_parser():
-    parser = argparse.ArgumentParser(description="Conversion from wave input")
-    parser.add_argument(
-        "--wav_dir",
-        type=str,
-        default=None,
-        required=True,
-        help="Source wave directory.",
-    )
-    parser.add_argument(
-        "--ref_wav_path",
-        type=str,
-        required=True,
-        help="Reference wave file path.",
-    )
-    parser.add_argument(
-        "--ppg2mel_model_train_config", "-c",
-        type=str,
-        default=None,
-        required=True,
-        help="Training config file (yaml file)",
-    )
-    parser.add_argument(
-        "--ppg2mel_model_file", "-m",
-        type=str,
-        default=None,
-        required=True,
-        help="ppg2mel model checkpoint file path"
-    )
-    parser.add_argument(
-        "--output_dir", "-o",
-        type=str,
-        default="vc_gens_vctk_oneshot",
-        help="Output folder to save the converted wave."
-    )
-    return parser
-
-
-def main():
-    parser = get_parser()
-    args = parser.parse_args()
-    convert(args)
-
-
-if __name__ == "__main__":
-    main()
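For context, the file deleted above was a standalone one-shot voice-conversion CLI. A minimal sketch of driving it programmatically, with Namespace fields mirroring get_parser() above (all paths are hypothetical placeholders, not part of the commit):

from argparse import Namespace

# Hypothetical checkpoint/sample paths; only the field names come from get_parser().
args = Namespace(
    wav_dir="samples/source",                          # directory of 16 kHz source .wav files
    ref_wav_path="samples/ref.wav",                    # reference speaker utterance
    ppg2mel_model_train_config="ppg2mel/config.yaml",  # yaml consumed by HpsYaml
    ppg2mel_model_file="ppg2mel/saved_models/step_340000.pth",
    output_dir="vc_gens_vctk_oneshot",
)
convert(args)  # PPG extraction -> ppg2mel decoding -> HiFi-GAN vocoding for each source wav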

run.py

@@ -43,15 +43,15 @@ def convert(args):
         device,
     )
     ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device)
-    vocoder.load_model('./vocoder/saved_models/pretrained/g_hifigan.pt', "./vocoder/hifigan/config_16k_.json")
-    # vocoder.load_model('./vocoder/saved_models/pretrained/g_02830000.pt')
+    # vocoder.load_model('./vocoder/saved_models/pretrained/g_hifigan.pt', "./vocoder/hifigan/config_16k_.json")
+    vocoder.load_model('./vocoder/saved_models/24k/g_02830000.pt')

     # Data related
     ref_wav_path = args.ref_wav_path
     ref_wav = preprocess_wav(ref_wav_path)
     ref_fid = os.path.basename(ref_wav_path)[:-4]
     # TODO: specify encoder
-    speacker_encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
+    speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))
     ref_spk_dvec = speacker_encoder.embed_utterance(ref_wav)
     ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device)
     ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
@@ -88,8 +88,8 @@ def convert(args):
         cnt += 1
         # continue
         mel_pred= mel_pred.transpose(0, 1)
-        y = vocoder.infer_waveform(mel_pred.cpu())
-        sf.write(wav_fname, y.squeeze(), 16000, "PCM_16")
+        y, output_sample_rate = vocoder.infer_waveform(mel_pred.cpu())
+        sf.write(wav_fname, y.squeeze(), output_sample_rate, "PCM_16")

     print("RTF:")
     print(total_rtf / cnt)
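The second hunk is the actual sample-rate fix: run.py now loads the 24 kHz HiFi-GAN checkpoint, but sf.write was still hardcoded to 16000, so converted audio played back slowed and pitch-shifted by a factor of 16/24. A self-contained illustration of that mismatch (synthetic tone, hypothetical filenames):

import numpy as np
import soundfile as sf

sr_true, sr_wrong = 24000, 16000
t = np.arange(sr_true) / sr_true            # one second of samples at 24 kHz
tone = 0.5 * np.sin(2 * np.pi * 440 * t)    # 440 Hz test tone

sf.write("ok.wav", tone, sr_true)     # plays 1.0 s at 440 Hz
sf.write("slow.wav", tone, sr_wrong)  # plays 1.5 s at ~293 Hz (440 * 16/24)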

(file name not shown; contains class Toolbox)

@@ -286,7 +286,7 @@ class Toolbox:
             self.ui.set_loading(i, seq_len)
         if self.ui.current_vocoder_fpath is not None:
             self.ui.log("")
-            wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
+            wav, sample_rate = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
         else:
             self.ui.log("Waveform generation with Griffin-Lim... ")
             wav = Synthesizer.griffin_lim(spec)
@@ -297,7 +297,7 @@ class Toolbox:
         b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
         b_starts = np.concatenate(([0], b_ends[:-1]))
         wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
-        breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
+        breaks = [np.zeros(int(0.15 * sample_rate))] * len(breaks)
         wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])

         # Trim excessive silences
@@ -306,7 +306,7 @@ class Toolbox:

         # Play it
         wav = wav / np.abs(wav).max() * 0.97
-        self.ui.play(wav, Synthesizer.sample_rate)
+        self.ui.play(wav, sample_rate)

         # Name it (history displayed in combobox)
         # TODO better naming for the combobox items?
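Note that only the vocoder branch binds sample_rate here; if the Griffin-Lim fallback runs, the later sample_rate references would raise a NameError. A defensive sketch (not part of this commit) seeds a default first, the same pattern the web app hunk below uses:

sample_rate = Synthesizer.sample_rate   # default covers the Griffin-Lim fallback
if self.ui.current_vocoder_fpath is not None:
    self.ui.log("")
    wav, sample_rate = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
else:
    self.ui.log("Waveform generation with Griffin-Lim... ")
    wav = Synthesizer.griffin_lim(spec)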

(file name not shown; the HiFi-GAN vocoder inference module)

@@ -7,6 +7,7 @@ from vocoder.hifigan.env import AttrDict
 from vocoder.hifigan.models import Generator

 generator = None   # type: Generator
+output_sample_rate = None
 _device = None
@@ -18,8 +19,8 @@ def load_checkpoint(filepath, device):
     return checkpoint_dict

-def load_model(weights_fpath, config_fpath="./vocoder/saved_models/pretrained/config.json", verbose=True):
-    global generator, _device
+def load_model(weights_fpath, config_fpath="./vocoder/saved_models/24k/config.json", verbose=True):
+    global generator, _device, output_sample_rate

     if verbose:
         print("Building hifigan")
@@ -28,6 +29,7 @@ def load_model(weights_fpath, config_fpath="./vocoder/saved_models/pretrained/config.json", verbose=True):
         data = f.read()
     json_config = json.loads(data)
     h = AttrDict(json_config)
+    output_sample_rate = h.sampling_rate

     torch.manual_seed(h.seed)
     if torch.cuda.is_available():
@@ -62,5 +64,5 @@ def infer_waveform(mel, progress_callback=None):
     audio = y_g_hat.squeeze()
     audio = audio.cpu().numpy()

-    return audio
+    return audio, output_sample_rate
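With this change the module reads sampling_rate from the checkpoint's JSON config at load time and returns it alongside every waveform. A caller sketch (the dummy mel only exercises the API; real callers pass a synthesizer-produced mel spectrogram):

import numpy as np
import soundfile as sf
from vocoder.hifigan import inference as gan_vocoder

gan_vocoder.load_model("./vocoder/saved_models/24k/g_02830000.pt")
mel = np.random.rand(80, 200).astype(np.float32)    # dummy mel, shape (n_mels, frames)
wav, sample_rate = gan_vocoder.infer_waveform(mel)  # sample_rate == h.sampling_rate
sf.write("out.wav", wav.squeeze(), sample_rate, "PCM_16")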

(file name not shown; the WaveRNN vocoder inference module)

@@ -61,4 +61,4 @@ def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800, progress_callback=None):
     mel = mel / hp.mel_max_abs_value
     mel = torch.from_numpy(mel[None, ...])
     wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback)
-    return wav
+    return wav, hp.sample_rate
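WaveRNN has no per-checkpoint JSON config, so its rate comes from the static hparams module (hp.sample_rate); returning it keeps both vocoder backends on the same (wav, sample_rate) convention. A matching caller sketch (checkpoint path hypothetical; mel is a synthesizer-produced mel spectrogram):

from vocoder.wavernn import inference as rnn_vocoder

rnn_vocoder.load_model("./vocoder/saved_models/pretrained/wavernn.pt")  # hypothetical path
wav, sample_rate = rnn_vocoder.infer_waveform(mel)  # sample_rate is the static hp.sample_rate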

(file name not shown; the web app, def webApp())

@@ -107,14 +107,15 @@ def webApp():
         embeds = [embed] * len(texts)
         specs = current_synt.synthesize_spectrograms(texts, embeds)
         spec = np.concatenate(specs, axis=1)
+        sample_rate = Synthesizer.sample_rate
         if "vocoder" in request.form and request.form["vocoder"] == "WaveRNN":
             wav = rnn_vocoder.infer_waveform(spec)
         else:
-            wav = gan_vocoder.infer_waveform(spec)
+            wav, sample_rate = gan_vocoder.infer_waveform(spec)

         # Return cooked wav
         out = io.BytesIO()
-        write(out, Synthesizer.sample_rate, wav.astype(np.float32))
+        write(out, sample_rate, wav.astype(np.float32))
         return Response(out, mimetype="audio/wav")

     @app.route('/', methods=['GET'])
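One caveat in this last hunk: after the WaveRNN change above, rnn_vocoder.infer_waveform also returns a (wav, sample_rate) tuple, but the WaveRNN branch is left unchanged, so wav is bound to a tuple and the later wav.astype(np.float32) raises AttributeError. A consistent sketch would unpack in both branches:

if "vocoder" in request.form and request.form["vocoder"] == "WaveRNN":
    wav, sample_rate = rnn_vocoder.infer_waveform(spec)
else:
    wav, sample_rate = gan_vocoder.infer_waveform(spec)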