pull/795/head
babysor00 2023-02-01 19:59:15 +08:00
parent 74a3fc97d0
commit e469bd06ae
23 changed files with 248 additions and 98 deletions

12
.vscode/launch.json vendored
View File

@@ -15,7 +15,8 @@
"name": "Python: Vocoder Preprocess",
"type": "python",
"request": "launch",
-"program": "vocoder_preprocess.py",
+"program": "control\\cli\\vocoder_preprocess.py",
+"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"args": ["..\\audiodata"]
},
@@ -23,7 +24,8 @@
"name": "Python: Vocoder Train",
"type": "python",
"request": "launch",
-"program": "vocoder_train.py",
+"program": "control\\cli\\vocoder_train.py",
+"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"args": ["dev", "..\\audiodata"]
},
@@ -32,6 +34,7 @@
"type": "python",
"request": "launch",
"program": "demo_toolbox.py",
+"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"args": ["-d","..\\audiodata"]
},
@@ -40,6 +43,7 @@
"type": "python",
"request": "launch",
"program": "demo_toolbox.py",
+"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"args": ["-d","..\\audiodata","-vc"]
},
@@ -47,9 +51,9 @@
"name": "Python: Synth Train",
"type": "python",
"request": "launch",
-"program": "synthesizer_train.py",
+"program": "train.py",
"console": "integratedTerminal",
-"args": ["my_run", "..\\"]
+"args": ["--type", "synth", "..\\audiodata\\SV2TTS\\synthesizer"]
},
{
"name": "Python: PPG Convert",

0
control/cli/__init__.py Normal file
View File

View File

@@ -14,7 +14,7 @@ if __name__ == "__main__":
"Path to the synthesizer training data that contains the audios and the train.txt file. "
"If you let everything as default, it should be <datasets_root>/SV2TTS/synthesizer/.")
parser.add_argument("-e", "--encoder_model_fpath", type=Path,
-default="encoder/saved_models/pretrained.pt", help=\
+default="data/ckpt/encoder/pretrained.pt", help=\
"Path your trained encoder model.")
parser.add_argument("-n", "--n_processes", type=int, default=4, help= \
"Number of parallel processes. An encoder is created for each, so you may need to lower "

View File

@@ -3,8 +3,7 @@ from models.synthesizer.train import train
from utils.argutils import print_args
import argparse
-if __name__ == "__main__":
+def new_train():
parser = argparse.ArgumentParser()
parser.add_argument("run_id", type=str, help= \
"Name for this model instance. If a model state from the same run ID was previously "
@@ -13,7 +12,7 @@ if __name__ == "__main__":
parser.add_argument("syn_dir", type=str, default=argparse.SUPPRESS, help= \
"Path to the synthesizer directory that contains the ground truth mel spectrograms, "
"the wavs and the embeds.")
-parser.add_argument("-m", "--models_dir", type=str, default="synthesizer/saved_models/", help=\
+parser.add_argument("-m", "--models_dir", type=str, default=f"data/ckpt/synthesizer/", help=\
"Path to the output directory that will contain the saved model weights and the logs.")
parser.add_argument("-s", "--save_every", type=int, default=1000, help= \
"Number of steps between updates of the model on the disk. Set to 0 to never save the "
@@ -28,10 +27,14 @@ if __name__ == "__main__":
parser.add_argument("--hparams", default="",
help="Hyperparameter overrides as a comma-separated list of name=value "
"pairs")
-args = parser.parse_args()
+args, _ = parser.parse_known_args()
print_args(args, parser)
args.hparams = hparams.parse(args.hparams)
# Run the training
train(**vars(args))
+if __name__ == "__main__":
+    new_train()
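
Note: wrapping the CLI in new_train() (and switching to parse_known_args) lets the same trainer run standalone or be imported by the new unified entry point. A minimal sketch, assuming this file is control/cli/synthesizer_train.py as the launch config and the train.py hunk below suggest:

# Standalone (run ID and dataset path are example values):
#   python control/cli/synthesizer_train.py my_run ..\audiodata\SV2TTS\synthesizer
# From the dispatcher in train.py:
from control.cli.synthesizer_train import new_train
new_train()  # parse_known_args tolerates extra flags such as the dispatcher's --type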

View File

@@ -0,0 +1,66 @@
import sys
import torch
import argparse
import numpy as np
from utils.load_yaml import HpsYaml
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
# For reproducibility, comment these may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def main():
# Arguments
parser = argparse.ArgumentParser(description=
'Training PPG2Mel VC model.')
parser.add_argument('--config', type=str,
help='Path to experiment config, e.g., config/vc.yaml')
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
parser.add_argument('--logdir', default='log/', type=str,
help='Logging path.', required=False)
parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
help='Checkpoint path.', required=False)
parser.add_argument('--outdir', default='result/', type=str,
help='Decode output path.', required=False)
parser.add_argument('--load', default=None, type=str,
help='Load pre-trained model (for training only)', required=False)
parser.add_argument('--warm_start', action='store_true',
help='Load model weights only, ignore specified layers.')
parser.add_argument('--seed', default=0, type=int,
help='Random seed for reproducable results.', required=False)
parser.add_argument('--njobs', default=8, type=int,
help='Number of threads for dataloader/decoding.', required=False)
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
parser.add_argument('--no-pin', action='store_true',
help='Disable pin-memory for dataloader')
parser.add_argument('--test', action='store_true', help='Test the model.')
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
parser.add_argument('--finetune', action='store_true', help='Finetune model')
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
###
paras = parser.parse_args()
setattr(paras, 'gpu', not paras.cpu)
setattr(paras, 'pin_memory', not paras.no_pin)
setattr(paras, 'verbose', not paras.no_msg)
# Make the config dict dot visitable
config = HpsYaml(paras.config)
np.random.seed(paras.seed)
torch.manual_seed(paras.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(paras.seed)
print(">>> OneShot VC training ...")
mode = "train"
solver = Solver(config, paras, mode)
solver.load_data()
solver.set_model()
solver.exec()
print(">>> Oneshot VC train finished!")
sys.exit(0)
if __name__ == "__main__":
main()
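
For reference, a hedged invocation sketch of this one-shot VC trainer (the config path and flag below are illustrative; logdir/ckpdir/outdir fall back to the argparse defaults above):

python control/cli/ppg2mel_train.py --config config/vc.yaml --oneshotvc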

View File

@@ -46,15 +46,16 @@ else:
raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
class Input(BaseModel):
message: str = Field(
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
)
local_audio_file: audio_input_selection = Field(
-..., alias="输入语音本地wav",
+..., alias="选择语音本地wav",
description="选择本地语音文件."
)
+record_audio_file: FileContent = Field(default=None, alias="录制语音",
+description="录音.", is_recorder=True, mime_type="audio/wav")
upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
description="拖拽或点击上传.", mime_type="audio/wav")
encoder: encoders = Field(
@@ -104,7 +105,12 @@ def synthesize(input: Input) -> Output:
gan_vocoder.load_model(Path(input.vocoder.value))
# load file
-if input.upload_audio_file != None:
+if input.record_audio_file != None:
+with open(TEMP_SOURCE_AUDIO, "w+b") as f:
+f.write(input.record_audio_file.as_bytes())
+f.seek(0)
+wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
+elif input.upload_audio_file != None:
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
f.write(input.upload_audio_file.as_bytes())
f.seek(0)

View File

@@ -37,6 +37,12 @@ def is_single_file_property(property: Dict) -> bool:
# TODO: binary?
return property.get("format") == "byte"
+def is_single_autio_property(property: Dict) -> bool:
+if property.get("type") != "string":
+return False
+# TODO: binary?
+return property.get("format") == "bytes"
def is_single_directory_property(property: Dict) -> bool:
if property.get("type") != "string":

View File

@@ -242,7 +242,14 @@ class InputUI:
file_extension = None
if "mime_type" in property:
file_extension = mimetypes.guess_extension(property["mime_type"])
+if "is_recorder" in property:
+from audio_recorder_streamlit import audio_recorder
+audio_bytes = audio_recorder()
+if audio_bytes:
+streamlit_app.audio(audio_bytes, format="audio/wav")
+return audio_bytes
uploaded_file = streamlit_app.file_uploader(
**streamlit_kwargs, accept_multiple_files=False, type=file_extension
)
@@ -262,6 +269,39 @@
streamlit_app.video(bytes, format=property.get("mime_type"))
return bytes
def _render_single_audio_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
# streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
from audio_recorder_streamlit import audio_recorder
audio_bytes = audio_recorder()
if audio_bytes:
streamlit_app.audio(audio_bytes, format="audio/wav")
return audio_bytes
# file_extension = None
# if "mime_type" in property:
# file_extension = mimetypes.guess_extension(property["mime_type"])
# uploaded_file = streamlit_app.file_uploader(
# **streamlit_kwargs, accept_multiple_files=False, type=file_extension
# )
# if uploaded_file is None:
# return None
# bytes = uploaded_file.getvalue()
# if property.get("mime_type"):
# if is_compatible_audio(property["mime_type"]):
# # Show audio
# streamlit_app.audio(bytes, format=property.get("mime_type"))
# if is_compatible_image(property["mime_type"]):
# # Show image
# streamlit_app.image(bytes)
# if is_compatible_video(property["mime_type"]):
# # Show video
# streamlit_app.video(bytes, format=property.get("mime_type"))
# return bytes
def _render_single_string_input(
self, streamlit_app: st, key: str, property: Dict
) -> Any:
@@ -820,7 +860,6 @@ def getOpyrator(mode: str) -> Opyrator:
from control.mkgui.app import synthesize
return Opyrator(synthesize)
def render_streamlit_ui() -> None:
# init
session_state = st.session_state
@@ -852,6 +891,13 @@ def render_streamlit_ui() -> None:
with left:
st.header("Control 控制")
# if session_state.mode in ["AI拟音", "VC拟音"] :
# from audiorecorder import audiorecorder
# audio = audiorecorder("Click to record", "Recording...")
# if len(audio) > 0:
# # To play audio in frontend:
# st.audio(audio.tobytes())
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
execute_selected = st.button(opyrator.action)
if execute_selected:

View File

@@ -38,7 +38,8 @@ recognized_datasets = [
"VoxCeleb2/dev/aac",
"VoxCeleb2/test/aac",
"VCTK-Corpus/wav48",
-"aidatatang_200zh/corpus",
+"aidatatang_200zh/corpus/test",
+"aidatatang_200zh/corpus/train",
"aishell3/test/wav",
"magicdata/train",
]

View File

@@ -3,7 +3,6 @@ from PyQt5 import QtGui
from PyQt5.QtWidgets import *
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
-from matplotlib.figure import Figure
from models.encoder.inference import plot_embedding_as_heatmap
from control.toolbox.utterance import Utterance
from pathlib import Path

View File

@@ -17,15 +17,15 @@ if __name__ == '__main__':
"supported datasets.", default=None)
parser.add_argument("-vc", "--vc_mode", action="store_true",
help="Voice Conversion Mode(PPG based)")
-parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
+parser.add_argument("-e", "--enc_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}encoder",
help="Directory containing saved encoder models")
-parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
+parser.add_argument("-s", "--syn_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}synthesizer",
help="Directory containing saved synthesizer models")
-parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
+parser.add_argument("-v", "--voc_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}vocoder",
help="Directory containing saved vocoder models")
-parser.add_argument("-ex", "--extractor_models_dir", type=Path, default="ppg_extractor/saved_models",
+parser.add_argument("-ex", "--extractor_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}ppg_extractor",
help="Directory containing saved extrator models")
-parser.add_argument("-cv", "--convertor_models_dir", type=Path, default="ppg2mel/saved_models",
+parser.add_argument("-cv", "--convertor_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}ppg2mel",
help="Directory containing saved convert models")
parser.add_argument("--cpu", action="store_true", help=\
"If True, processing is done on CPU, even when a GPU is available.")

View File

@@ -3,7 +3,7 @@ import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from models.vocoder.fregan.utils import init_weights, get_padding
+from utils.util import init_weights, get_padding
LRELU_SLOPE = 0.1

View File

@@ -27,21 +27,12 @@ def plot_spectrogram(spectrogram):
return fig
-def init_weights(m, mean=0.0, std=0.01):
-classname = m.__class__.__name__
-if classname.find("Conv") != -1:
-m.weight.data.normal_(mean, std)
def apply_weight_norm(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
weight_norm(m)
-def get_padding(kernel_size, dilation=1):
-return int((kernel_size*dilation - dilation)/2)
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)

View File

@@ -3,7 +3,7 @@ import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from models.vocoder.hifigan.utils import init_weights, get_padding
+from utils.util import init_weights, get_padding
LRELU_SLOPE = 0.1

View File

@@ -6,7 +6,6 @@ from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt
def plot_spectrogram(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
@@ -19,12 +18,6 @@ def plot_spectrogram(spectrogram):
return fig
-def init_weights(m, mean=0.0, std=0.01):
-classname = m.__class__.__name__
-if classname.find("Conv") != -1:
-m.weight.data.normal_(mean, std)
def apply_weight_norm(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
@@ -55,4 +48,3 @@ def scan_checkpoint(cp_dir, prefix):
if len(cp_list) == 0:
return None
return sorted(cp_list)[-1]

View File

@@ -1,7 +1,7 @@
from torch.utils.data import Dataset
from pathlib import Path
from models.vocoder.wavernn import audio
-import vocoder.wavernn.hparams as hp
+import models.vocoder.wavernn.hparams as hp
import numpy as np
import torch

View File

@@ -1,7 +1,7 @@
import math
import numpy as np
import librosa
-import vocoder.wavernn.hparams as hp
+import models.vocoder.wavernn.hparams as hp
from scipy.signal import lfilter
import soundfile as sf

View File

@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
from pathlib import Path
from torch import optim
import torch.nn.functional as F
-import vocoder.wavernn.hparams as hp
+import models.vocoder.wavernn.hparams as hp
import numpy as np
import time
import torch

View File

3
pre.py
View File

@@ -16,6 +16,7 @@ recognized_datasets = [
"data_aishell"
]
+#TODO: add for emotional data
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Preprocesses audio files from datasets, encodes them as mel spectrograms "
@@ -42,7 +43,7 @@ if __name__ == "__main__":
(these are used to split long audio files into sub-utterances.)")
parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
"Name of the dataset to process, allowing values: magicdata, aidatatang_200zh, aishell3, data_aishell.")
-parser.add_argument("-e", "--encoder_model_fpath", type=Path, default="encoder/saved_models/pretrained.pt", help=\
+parser.add_argument("-e", "--encoder_model_fpath", type=Path, default="data/ckpt/encoder/pretrained.pt", help=\
"Path your trained encoder model.")
parser.add_argument("-ne", "--n_processes_embed", type=int, default=1, help=\
"Number of processes in parallel.An encoder is created for each, so you may need to lower "

View File

@@ -25,4 +25,6 @@ streamlit==1.8.0
PyYAML==5.4.1
torch_complex
espnet
PyWavelets
+monotonic-align==0.0.3
+transformers==4.26.0

View File

@@ -5,63 +5,18 @@ import numpy as np
from utils.load_yaml import HpsYaml
from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
-# For reproducibility, comment these may speed up training
-torch.backends.cudnn.deterministic = True
-torch.backends.cudnn.benchmark = False
def main():
# Arguments
-parser = argparse.ArgumentParser(description=
-'Training PPG2Mel VC model.')
-parser.add_argument('--config', type=str,
-help='Path to experiment config, e.g., config/vc.yaml')
-parser.add_argument('--name', default=None, type=str, help='Name for logging.')
-parser.add_argument('--logdir', default='log/', type=str,
-help='Logging path.', required=False)
-parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
-help='Checkpoint path.', required=False)
-parser.add_argument('--outdir', default='result/', type=str,
-help='Decode output path.', required=False)
-parser.add_argument('--load', default=None, type=str,
-help='Load pre-trained model (for training only)', required=False)
-parser.add_argument('--warm_start', action='store_true',
-help='Load model weights only, ignore specified layers.')
-parser.add_argument('--seed', default=0, type=int,
-help='Random seed for reproducable results.', required=False)
-parser.add_argument('--njobs', default=8, type=int,
-help='Number of threads for dataloader/decoding.', required=False)
-parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
-parser.add_argument('--no-pin', action='store_true',
-help='Disable pin-memory for dataloader')
-parser.add_argument('--test', action='store_true', help='Test the model.')
-parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
-parser.add_argument('--finetune', action='store_true', help='Finetune model')
-parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
-parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
-parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
+preparser = argparse.ArgumentParser(description=
+'Training model.')
+preparser.add_argument('--type', type=str,
+help='type of training ')
###
-paras = parser.parse_args()
-setattr(paras, 'gpu', not paras.cpu)
-setattr(paras, 'pin_memory', not paras.no_pin)
-setattr(paras, 'verbose', not paras.no_msg)
-# Make the config dict dot visitable
-config = HpsYaml(paras.config)
-np.random.seed(paras.seed)
-torch.manual_seed(paras.seed)
-if torch.cuda.is_available():
-torch.cuda.manual_seed_all(paras.seed)
-print(">>> OneShot VC training ...")
-mode = "train"
-solver = Solver(config, paras, mode)
-solver.load_data()
-solver.set_model()
-solver.exec()
-print(">>> Oneshot VC train finished!")
-sys.exit(0)
+paras, _ = preparser.parse_known_args()
+if paras.type == "synth":
+from control.cli.synthesizer_train import new_train
+new_train()
if __name__ == "__main__":
main()
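
With this change train.py is only a thin dispatcher: it reads --type with parse_known_args and hands the remaining arguments to the selected trainer (only the synth branch is wired up in this hunk). The invocation below mirrors the updated VS Code launch configuration; the dataset path is an example:

python train.py --type synth ..\audiodata\SV2TTS\synthesizer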

View File

@@ -1,4 +1,7 @@
import matplotlib
+from torch.nn import functional as F
+import torch
matplotlib.use('Agg')
import time
@@ -48,3 +51,78 @@ class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2,3) * mask
return path
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
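
The helpers added here are the usual VITS-style masking/slicing utilities. A minimal usage sketch, assuming they land in utils/util.py as the changed fregan/hifigan imports above suggest (shapes are illustrative):

import torch
from utils.util import get_padding, sequence_mask, rand_slice_segments

lengths = torch.tensor([3, 5])
mask = sequence_mask(lengths)                 # bool mask [2, 5], True where frame index < length
x = torch.randn(2, 80, 10)                    # [batch, mel_channels, frames]
seg, ids = rand_slice_segments(x, torch.tensor([10, 10]), segment_size=4)
# seg: [2, 80, 4] random training windows; ids: chosen start frame per batch item
pad = get_padding(kernel_size=7, dilation=1)  # -> 3, i.e. "same" padding for odd kernels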