diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index 44e047d..7190704 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -296,8 +296,8 @@ class AMPModel(ModelBase): self.G_weights = self.encoder.get_weights() + self.decoder.get_weights() - if random_warp: - self.G_weights += self.inter_src.get_weights() + self.inter_dst.get_weights() + #if random_warp: + # self.G_weights += self.inter_src.get_weights() + self.inter_dst.get_weights() self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu)