mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
68 lines
2.9 KiB
Python
68 lines
2.9 KiB
Python
import sys
|
|
import torch
|
|
import argparse
|
|
import numpy as np
|
|
from utils.load_yaml import HpsYaml
|
|
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
|
|
|
# Reproducibility settings: turn off cuDNN benchmark mode and request
# deterministic cuDNN behaviour. Commenting these two lines out may speed up
# training.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
|
|
|
|
def _build_parser():
    """Create the command-line parser for the PPG2Mel VC training script."""
    parser = argparse.ArgumentParser(description='Training PPG2Mel VC model.')
    # --config is required: its value is passed straight to HpsYaml in main(),
    # so a missing path should be reported as a usage error by argparse
    # instead of surfacing later as a failure on a None path.
    parser.add_argument('--config', type=str, required=True,
                        help='Path to experiment config, e.g., config/vc.yaml')
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--logdir', default='log/', type=str,
                        help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='ckpt/', type=str,
                        help='Checkpoint path.', required=False)
    parser.add_argument('--outdir', default='result/', type=str,
                        help='Decode output path.', required=False)
    parser.add_argument('--load', default=None, type=str,
                        help='Load pre-trained model (for training only)', required=False)
    parser.add_argument('--warm_start', action='store_true',
                        help='Load model weights only, ignore specified layers.')
    parser.add_argument('--seed', default=0, type=int,
                        help='Random seed for reproducable results.', required=False)
    parser.add_argument('--njobs', default=8, type=int,
                        help='Number of threads for dataloader/decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--no-pin', action='store_true',
                        help='Disable pin-memory for dataloader')
    parser.add_argument('--test', action='store_true', help='Test the model.')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
    parser.add_argument('--finetune', action='store_true', help='Finetune model')
    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
    return parser


def main():
    """Parse the command line, seed the RNGs and run one-shot VC training.

    The parsed namespace is extended with three derived booleans (``gpu``,
    ``pin_memory``, ``verbose``) that are the negations of the ``--cpu``,
    ``--no-pin`` and ``--no-msg`` switches, and is then handed to ``Solver``
    together with the loaded config.

    Raises:
        SystemExit: with status 2 when the command line is invalid (for
            example ``--config`` is missing), and with status 0 once the
            solver has finished.
    """
    paras = _build_parser().parse_args()

    # Positive-sense flags derived from the "disable" switches.
    paras.gpu = not paras.cpu
    paras.pin_memory = not paras.no_pin
    paras.verbose = not paras.no_msg

    # Make the config dict dot visitable
    config = HpsYaml(paras.config)

    # Seed NumPy and PyTorch (CPU, plus every CUDA device when available).
    np.random.seed(paras.seed)
    torch.manual_seed(paras.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(paras.seed)

    print(">>> OneShot VC training ...")
    # NOTE(review): the mode is fixed to "train"; the --test switch is parsed
    # but not consulted here.
    mode = "train"
    solver = Solver(config, paras, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")
    sys.exit(0)
|
|
|
|
# Run the trainer only when this file is executed as a script.
if __name__ == "__main__":
    main()
|