mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Add code to control finetune layers (#154)
This commit is contained in:
parent
31bc6656c3
commit
724194a4de
|
@ -62,9 +62,11 @@ hparams = HParams(
|
|||
tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
|
||||
tts_eval_interval = 500, # Number of steps between model evaluation (sample generation)
|
||||
# Set to -1 to generate after completing epoch, or 0 to disable
|
||||
|
||||
tts_eval_num_samples = 1, # Makes this number of samples
|
||||
|
||||
## For finetune usage, if set, only selected layers will be trained, available: encoder,encoder_proj,gst,decoder,postnet,post_proj
|
||||
tts_finetune_layers = [],
|
||||
|
||||
### Data Preprocessing
|
||||
max_mel_frames = 900,
|
||||
rescale = True,
|
||||
|
|
|
@ -496,6 +496,15 @@ class Tacotron(nn.Module):
|
|||
for p in self.parameters():
|
||||
if p.dim() > 1: nn.init.xavier_uniform_(p)
|
||||
|
||||
def finetune_partial(self, whitelist_layers):
|
||||
self.zero_grad()
|
||||
for name, child in self.named_children():
|
||||
if name in whitelist_layers:
|
||||
print("Trainable Layer: %s" % name)
|
||||
print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
|
||||
for param in child.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def get_step(self):
|
||||
return self.step.data.item()
|
||||
|
||||
|
|
|
@ -93,7 +93,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
|||
speaker_embedding_size=hparams.speaker_embedding_size).to(device)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = optim.Adam(model.parameters())
|
||||
optimizer = optim.Adam(model.parameters(), amsgrad=True)
|
||||
|
||||
# Load the weights
|
||||
if force_restart or not weights_fpath.exists():
|
||||
|
@ -146,7 +146,6 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
|||
continue
|
||||
|
||||
model.r = r
|
||||
|
||||
# Begin the training
|
||||
simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
|
||||
("Batch Size", batch_size),
|
||||
|
@ -155,6 +154,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
|||
|
||||
for p in optimizer.param_groups:
|
||||
p["lr"] = lr
|
||||
if hparams.tts_finetune_layers is not None and len(hparams.tts_finetune_layers) > 0:
|
||||
model.finetune_partial(hparams.tts_finetune_layers)
|
||||
|
||||
data_loader = DataLoader(dataset,
|
||||
collate_fn=collate_synthesizer,
|
||||
|
|
Loading…
Reference in New Issue
Block a user