mirror of https://github.com/babysor/MockingBird.git (synced 2024-03-22 13:11:31 +08:00)
Add code to control finetune layers (#154)
This commit is contained in:
parent
31bc6656c3
commit
724194a4de
@@ -62,9 +62,11 @@ hparams = HParams(
         tts_clip_grad_norm = 1.0,               # clips the gradient norm to prevent explosion - set to None if not needed
         tts_eval_interval = 500,                # Number of steps between model evaluation (sample generation)
                                                 # Set to -1 to generate after completing epoch, or 0 to disable
         tts_eval_num_samples = 1,               # Makes this number of samples
 
+        ## For finetune usage, if set, only selected layers will be trained, available: encoder,encoder_proj,gst,decoder,postnet,post_proj
+        tts_finetune_layers = [],
 
         ### Data Preprocessing
         max_mel_frames = 900,
         rescale = True,
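The new hparam defaults to an empty list, which keeps the existing full-model training behaviour. A minimal sketch of how a user might enable partial finetuning; the module path `synthesizer.hparams` is an assumption, and the layer names come from the comment above:

```python
# Sketch only: adapt the decoder and postnet to new data, keeping the other
# layers (encoder, encoder_proj, gst, post_proj) frozen.
from synthesizer.hparams import hparams   # assumed module path for this HParams object

hparams.tts_finetune_layers = ["decoder", "postnet"]
```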
@@ -496,6 +496,15 @@ class Tacotron(nn.Module):
         for p in self.parameters():
             if p.dim() > 1: nn.init.xavier_uniform_(p)
 
+    def finetune_partial(self, whitelist_layers):
+        self.zero_grad()
+        for name, child in self.named_children():
+            if name in whitelist_layers:
+                print("Trainable Layer: %s" % name)
+                print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
+            else:
+                for param in child.parameters():
+                    param.requires_grad = False
+
     def get_step(self):
         return self.step.data.item()
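`finetune_partial` walks the model's direct children and freezes every child whose name is not in the whitelist; whitelisted children keep `requires_grad=True` and have their name and parameter count printed. A self-contained sketch of the same pattern on a toy module (module names and sizes are illustrative, not taken from the repository):

```python
import numpy as np
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.decoder = nn.Linear(16, 8)
        self.postnet = nn.Linear(8, 8)

def finetune_partial(model, whitelist_layers):
    # Freeze every top-level child that is not whitelisted; leave the rest trainable.
    model.zero_grad()
    for name, child in model.named_children():
        if name in whitelist_layers:
            n_params = sum(np.prod(p.size()) for p in child.parameters())
            print("Trainable Layer: %s (%d parameters)" % (name, n_params))
        else:
            for param in child.parameters():
                param.requires_grad = False

model = ToyModel()
finetune_partial(model, ["decoder"])
print([n for n, p in model.named_parameters() if p.requires_grad])
# -> ['decoder.weight', 'decoder.bias']
```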
@@ -93,7 +93,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
                      speaker_embedding_size=hparams.speaker_embedding_size).to(device)
 
     # Initialize the optimizer
-    optimizer = optim.Adam(model.parameters())
+    optimizer = optim.Adam(model.parameters(), amsgrad=True)
 
     # Load the weights
     if force_restart or not weights_fpath.exists():
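The functional change in this hunk is the switch to the AMSGrad variant of Adam; the optimizer is still constructed over all of `model.parameters()`. Parameters that `finetune_partial` later freezes never receive a gradient, so Adam simply skips them. An alternative, shown only as an illustrative sketch and not what this commit does, is to hand the optimizer just the trainable subset after freezing:

```python
import torch.nn as nn
import torch.optim as optim

# Toy module standing in for the real model; pretend its first layer was frozen.
net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
for p in net[0].parameters():
    p.requires_grad = False

# Illustrative alternative: optimize only the parameters that still require grad.
trainable = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable, amsgrad=True)
```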
@@ -146,7 +146,6 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
             continue
 
         model.r = r
 
         # Begin the training
         simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
                       ("Batch Size", batch_size),
@@ -155,6 +154,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
 
         for p in optimizer.param_groups:
             p["lr"] = lr
+        if hparams.tts_finetune_layers is not None and len(hparams.tts_finetune_layers) > 0:
+            model.finetune_partial(hparams.tts_finetune_layers)
 
         data_loader = DataLoader(dataset,
                                  collate_fn=collate_synthesizer,
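Because this guard runs after the optimizer has already been created, the frozen layers stay registered with Adam but are never updated: their `.grad` remains `None`, and the optimizer skips them. A small self-contained check of that ordering on a toy module (not the Tacotron model):

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Same ordering as in train(): optimizer first, partial freeze afterwards.
net = nn.ModuleDict({"encoder": nn.Linear(4, 4), "decoder": nn.Linear(4, 4)})
optimizer = optim.Adam(net.parameters(), amsgrad=True)   # built over everything

for p in net["encoder"].parameters():                    # freeze a non-whitelisted child
    p.requires_grad = False

before = net["encoder"].weight.detach().clone()
loss = net["decoder"](net["encoder"](torch.randn(2, 4))).sum()
loss.backward()
optimizer.step()

# The frozen child received no gradient, so Adam left its weights unchanged.
assert torch.equal(before, net["encoder"].weight)
```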