mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
143 lines
4.5 KiB
Python
143 lines
4.5 KiB
Python
|
import time
|
||
|
import os
|
||
|
import argparse
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
import glob
|
||
|
from pathlib import Path
|
||
|
from tqdm import tqdm
|
||
|
from ppg_extractor import load_model
|
||
|
import librosa
|
||
|
import soundfile as sf
|
||
|
from utils.load_yaml import HpsYaml
|
||
|
|
||
|
from encoder.audio import preprocess_wav
|
||
|
from encoder import inference as speacker_encoder
|
||
|
from vocoder.hifigan import inference as vocoder
|
||
|
from ppg2mel import MelDecoderMOLv2
|
||
|
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
||
|
|
||
|
|
||
|
def _build_ppg2mel_model(model_config, model_file, device):
|
||
|
ppg2mel_model = MelDecoderMOLv2(
|
||
|
**model_config["model"]
|
||
|
).to(device)
|
||
|
ckpt = torch.load(model_file, map_location=device)
|
||
|
ppg2mel_model.load_state_dict(ckpt["model"])
|
||
|
ppg2mel_model.eval()
|
||
|
return ppg2mel_model
|
||
|
|
||
|
|
||
|
@torch.no_grad()
|
||
|
def convert(args):
|
||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
|
output_dir = args.output_dir
|
||
|
os.makedirs(output_dir, exist_ok=True)
|
||
|
|
||
|
step = os.path.basename(args.ppg2mel_model_file)[:-4].split("_")[-1]
|
||
|
|
||
|
# Build models
|
||
|
print("Load PPG-model, PPG2Mel-model, Vocoder-model...")
|
||
|
ppg_model = load_model(
|
||
|
Path('./ppg_extractor/saved_models/24epoch.pt'),
|
||
|
device,
|
||
|
)
|
||
|
ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device)
|
||
|
# vocoder.load_model('./vocoder/saved_models/pretrained/g_hifigan.pt', "./vocoder/hifigan/config_16k_.json")
|
||
|
vocoder.load_model('./vocoder/saved_models/24k/g_02830000.pt')
|
||
|
# Data related
|
||
|
ref_wav_path = args.ref_wav_path
|
||
|
ref_wav = preprocess_wav(ref_wav_path)
|
||
|
ref_fid = os.path.basename(ref_wav_path)[:-4]
|
||
|
|
||
|
# TODO: specify encoder
|
||
|
speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))
|
||
|
ref_spk_dvec = speacker_encoder.embed_utterance(ref_wav)
|
||
|
ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device)
|
||
|
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
||
|
|
||
|
source_file_list = sorted(glob.glob(f"{args.wav_dir}/*.wav"))
|
||
|
print(f"Number of source utterances: {len(source_file_list)}.")
|
||
|
|
||
|
total_rtf = 0.0
|
||
|
cnt = 0
|
||
|
for src_wav_path in tqdm(source_file_list):
|
||
|
# Load the audio to a numpy array:
|
||
|
src_wav, _ = librosa.load(src_wav_path, sr=16000)
|
||
|
src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device)
|
||
|
src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device)
|
||
|
ppg = ppg_model(src_wav_tensor, src_wav_lengths)
|
||
|
|
||
|
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
||
|
min_len = min(ppg.shape[1], len(lf0_uv))
|
||
|
|
||
|
ppg = ppg[:, :min_len]
|
||
|
lf0_uv = lf0_uv[:min_len]
|
||
|
|
||
|
start = time.time()
|
||
|
_, mel_pred, att_ws = ppg2mel_model.inference(
|
||
|
ppg,
|
||
|
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
|
||
|
spembs=ref_spk_dvec,
|
||
|
)
|
||
|
src_fid = os.path.basename(src_wav_path)[:-4]
|
||
|
wav_fname = f"{output_dir}/vc_{src_fid}_ref_{ref_fid}_step{step}.wav"
|
||
|
mel_len = mel_pred.shape[0]
|
||
|
rtf = (time.time() - start) / (0.01 * mel_len)
|
||
|
total_rtf += rtf
|
||
|
cnt += 1
|
||
|
# continue
|
||
|
mel_pred= mel_pred.transpose(0, 1)
|
||
|
y, output_sample_rate = vocoder.infer_waveform(mel_pred.cpu())
|
||
|
sf.write(wav_fname, y.squeeze(), output_sample_rate, "PCM_16")
|
||
|
|
||
|
print("RTF:")
|
||
|
print(total_rtf / cnt)
|
||
|
|
||
|
|
||
|
def get_parser():
|
||
|
parser = argparse.ArgumentParser(description="Conversion from wave input")
|
||
|
parser.add_argument(
|
||
|
"--wav_dir",
|
||
|
type=str,
|
||
|
default=None,
|
||
|
required=True,
|
||
|
help="Source wave directory.",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--ref_wav_path",
|
||
|
type=str,
|
||
|
required=True,
|
||
|
help="Reference wave file path.",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--ppg2mel_model_train_config", "-c",
|
||
|
type=str,
|
||
|
default=None,
|
||
|
required=True,
|
||
|
help="Training config file (yaml file)",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--ppg2mel_model_file", "-m",
|
||
|
type=str,
|
||
|
default=None,
|
||
|
required=True,
|
||
|
help="ppg2mel model checkpoint file path"
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--output_dir", "-o",
|
||
|
type=str,
|
||
|
default="vc_gens_vctk_oneshot",
|
||
|
help="Output folder to save the converted wave."
|
||
|
)
|
||
|
|
||
|
return parser
|
||
|
|
||
|
def main():
|
||
|
parser = get_parser()
|
||
|
args = parser.parse_args()
|
||
|
convert(args)
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
main()
|