AMP: don't train inters with random_warp-off

iperov 2021-09-03 13:50:31 +04:00
parent 3fe8ce86b1
commit b1990d421a


@@ -294,7 +294,10 @@ class AMPModel(ModelBase):
             clipnorm = 1.0 if self.options['clipgrad'] else 0.0
             lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0
-            self.G_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()
+            self.G_weights = self.encoder.get_weights() + self.decoder.get_weights()
+            if random_warp:
+                self.G_weights += self.inter_src.get_weights() + self.inter_dst.get_weights()
             self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
             self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu)
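
In effect, the inter_src and inter_dst weights now enter G_weights only while random_warp is enabled; with random_warp off they are never handed to the AdaBelief optimizer, so the inters stay frozen while the encoder and decoder keep training. A minimal runnable sketch of that collection logic (Module and collect_G_weights are hypothetical stand-ins for illustration, not DeepFaceLab's nn API):

class Module:
    # Hypothetical stand-in: each real nn module exposes get_weights(),
    # returning its list of trainable tensors, as in AMPModel above.
    def __init__(self, name):
        self.name = name

    def get_weights(self):
        return [f'{self.name}.w0']  # one placeholder "tensor" per module

encoder   = Module('encoder')
inter_src = Module('inter_src')
inter_dst = Module('inter_dst')
decoder   = Module('decoder')

def collect_G_weights(random_warp):
    # Encoder and decoder are always trainable.
    g_weights = encoder.get_weights() + decoder.get_weights()
    # The inters join the trainable set only when random_warp is on;
    # otherwise they are excluded and the optimizer never updates them.
    if random_warp:
        g_weights += inter_src.get_weights() + inter_dst.get_weights()
    return g_weights

print(len(collect_G_weights(random_warp=True)))   # 4 weight tensors
print(len(collect_G_weights(random_warp=False)))  # 2 weight tensors

Since initialize_variables receives only G_weights, any module left out of that list is effectively frozen for the rest of the run.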