AMP: don't train inters with random_warp-off

iperov 2021-09-03 13:50:31 +04:00
parent 3fe8ce86b1
commit b1990d421a


@@ -122,9 +122,9 @@ class AMPModel(ModelBase):
morph_factor = self.options['morph_factor']
gan_power = self.gan_power = self.options['gan_power']
random_warp = self.options['random_warp']
blur_out_mask = self.options['blur_out_mask']
ct_mode = self.options['ct_mode']
if ct_mode == 'none':
ct_mode = None
@@ -294,7 +294,10 @@ class AMPModel(ModelBase):
clipnorm = 1.0 if self.options['clipgrad'] else 0.0
lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0
- self.G_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()
+ self.G_weights = self.encoder.get_weights() + self.decoder.get_weights()
+ if random_warp:
+     self.G_weights += self.inter_src.get_weights() + self.inter_dst.get_weights()
self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu)
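# --- Illustrative sketch (not part of the diff): the core of this commit is that the
# trainable variable list is assembled conditionally — encoder and decoder are always
# trained, while the two inter networks are only included when random_warp is enabled.
# The Keras-style stand-ins below are assumptions; DeepFaceLab uses its own archi
# objects and nn.AdaBelief optimizer.
import tensorflow as tf

def collect_generator_weights(encoder, inter_src, inter_dst, decoder, random_warp):
    weights = encoder.trainable_weights + decoder.trainable_weights
    if random_warp:
        # With random_warp off the inters are effectively frozen: their variables are
        # simply left out of the optimizer's list, so they receive no updates.
        weights += inter_src.trainable_weights + inter_dst.trainable_weights
    return weights
# ---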
@@ -352,64 +355,64 @@ class AMPModel(ModelBase):
gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code)
gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code)
inter_dims_bin = int(inter_dims*morph_factor)
with tf.device(f'/CPU:0'):
inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )),
tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0)
inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None])
gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
gpu_dst_code = gpu_dst_inter_dst_code
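# --- Illustrative sketch (not part of the diff): the random binomial channel mask that
# mixes the two inter codes for the src branch during training. NCHW shapes, a statically
# known batch size, and the function name are assumptions for the example.
import tensorflow as tf

def random_channel_mix(code_a, code_b, morph_factor):
    # code_a, code_b: [batch, channels, h, w]
    batch = code_a.shape[0]
    channels = code_a.shape[1]
    ones_count = int(channels * morph_factor)
    row = tf.concat([tf.ones([ones_count]), tf.zeros([channels - ones_count])], axis=0)
    # Each sample gets its own shuffled 0/1 pattern over the channel axis.
    mask = tf.stack([tf.random.shuffle(row) for _ in range(batch)], axis=0)
    mask = tf.stop_gradient(mask[:, :, None, None])   # broadcast over spatial dims
    # Channels marked 1 come from code_a (inter_src), channels marked 0 from code_b (inter_dst).
    return code_a * mask + code_b * (1.0 - mask)
# ---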
inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]),
tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 )
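# --- Illustrative sketch (not part of the diff): the deterministic morph used for the
# src->dst prediction — the first inter_dims*morph_value channels come from inter_src,
# the rest from inter_dst. morph_value is taken as a plain Python float here; the model
# feeds it through the morph_value_t placeholder instead.
import tensorflow as tf

def morph_codes(inter_src_code, inter_dst_code, morph_value):
    # inter_*_code: [batch, inter_dims, inter_res, inter_res]
    channels = inter_src_code.shape[1]
    split = int(channels * morph_value)
    return tf.concat([inter_src_code[:, :split],
                      inter_dst_code[:, split:]], axis=1)
# ---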
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
gpu_target_srcm_anti = 1-gpu_target_srcm
gpu_target_dstm_anti = 1-gpu_target_dstm
gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32)
gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32)
gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur
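# --- Illustrative sketch (not part of the diff): how the hard face mask becomes the
# soft "*_blur" mask used for loss weighting. nn.gaussian_blur is DeepFaceLab's own op;
# a separable depthwise convolution on NHWC tensors stands in for it here (an assumption).
import numpy as np
import tensorflow as tf

def gaussian_blur_nhwc(x, radius):
    # Build a 1D Gaussian kernel and apply it separably (horizontal, then vertical).
    sigma = max(radius / 2.0, 0.5)
    coords = np.arange(-radius, radius + 1, dtype=np.float32)
    k = np.exp(-(coords ** 2) / (2.0 * sigma ** 2))
    k /= k.sum()
    c = x.shape[-1]
    kh = tf.constant(k.reshape(1, -1, 1, 1) * np.ones((1, 1, c, 1), np.float32))
    kv = tf.constant(k.reshape(-1, 1, 1, 1) * np.ones((1, 1, c, 1), np.float32))
    x = tf.nn.depthwise_conv2d(x, kh, [1, 1, 1, 1], 'SAME')
    return tf.nn.depthwise_conv2d(x, kv, [1, 1, 1, 1], 'SAME')

def soft_mask(mask_nhwc, resolution):
    # Blur, clip at 0.5 and rescale: the face core stays at 1.0 while the outer
    # edge fades smoothly to 0 — the behaviour of the *_blur masks above.
    blurred = gaussian_blur_nhwc(mask_nhwc, resolution // 32)
    return tf.clip_by_value(blurred, 0.0, 0.5) * 2.0
# ---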
if blur_out_mask:
#gpu_target_src = gpu_target_src*gpu_target_srcm_blur + nn.gaussian_blur(gpu_target_src, resolution // 32)*gpu_target_srcm_anti_blur
#gpu_target_dst = gpu_target_dst*gpu_target_dstm_blur + nn.gaussian_blur(gpu_target_dst, resolution // 32)*gpu_target_dstm_anti_blur
bg_blur_div = 128
gpu_target_src = gpu_target_src*gpu_target_srcm + \
tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, resolution / bg_blur_div),
(1-nn.gaussian_blur(gpu_target_srcm, resolution / bg_blur_div) ) ) * gpu_target_srcm_anti
gpu_target_dst = gpu_target_dst*gpu_target_dstm + \
tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, resolution / bg_blur_div),
(1-nn.gaussian_blur(gpu_target_dstm, resolution / bg_blur_div)) ) * gpu_target_dstm_anti
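# --- Illustrative sketch (not part of the diff): what blur_out_mask does. The region
# outside the face mask is replaced by a heavily blurred copy; dividing by the blurred
# inverse mask renormalizes pixels near the mask edge so face colour does not bleed into
# the background. gaussian_blur_nhwc is the helper sketched above; bg_blur_div=128
# follows the source, the integer radius cast is an assumption of this sketch.
import tensorflow as tf

def blur_out_background(img, mask, resolution, bg_blur_div=128):
    anti_mask = 1.0 - mask
    radius = max(int(resolution / bg_blur_div), 1)
    blurred_bg = tf.math.divide_no_nan(gaussian_blur_nhwc(img * anti_mask, radius),
                                       1.0 - gaussian_blur_nhwc(mask, radius))
    return img * mask + blurred_bg * anti_mask
# ---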
gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur
gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur
gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur
# Structural loss
gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
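# --- Illustrative sketch (not part of the diff): the structural term is DSSIM
# ((1 - SSIM) / 2) evaluated at two filter sizes and weighted by 5. nn.dssim is
# DeepFaceLab's op; tf.image.ssim on NHWC tensors stands in for it here (an assumption).
import tensorflow as tf

def structural_loss(target, pred, resolution):
    loss = 0.0
    for divisor in (11.6, 23.2):                      # two scales, as in the source
        ssim = tf.image.ssim(target, pred, max_val=1.0,
                             filter_size=int(resolution / divisor))
        loss += 5.0 * (1.0 - ssim) / 2.0              # per-sample DSSIM, weighted by 5
    return loss
# ---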
@@ -434,8 +437,8 @@ class AMPModel(ModelBase):
# dst-dst background weak loss
gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked)
if gan_power != 0:
gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked)
gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked)
@@ -453,7 +456,7 @@ class AMPModel(ModelBase):
gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \
DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
) * gan_power
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
@@ -473,7 +476,7 @@ class AMPModel(ModelBase):
src_loss = tf.concat(gpu_src_losses, 0)
dst_loss = tf.concat(gpu_dst_losses, 0)
train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients))
if gan_power != 0:
GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) )
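# --- Illustrative sketch (not part of the diff): averaging the per-GPU (gradient,
# variable) lists before a single optimizer step, which is the role of
# nn.average_gv_list + get_update_op in DeepFaceLab's graph-mode wrapper. The plain-TF
# helper below is an assumption, not the project's API.
import tensorflow as tf

def average_gradients(tower_grads_and_vars):
    # tower_grads_and_vars: one [(grad, var), ...] list per GPU, same variable order.
    averaged = []
    for grads_and_var in zip(*tower_grads_and_vars):
        grads = [g for g, _ in grads_and_var]
        var = grads_and_var[0][1]
        averaged.append((tf.add_n(grads) / float(len(grads)), var))
    return averaged
# ---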
@@ -554,8 +557,8 @@ class AMPModel(ModelBase):
dst_generators_count = cpu_count // 2
if ct_mode is not None:
src_generators_count = int(src_generators_count * 1.5)
self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),