Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2024-03-22 13:10:55 +08:00)
AMP: don't train inters with random_warp-off
This commit is contained in:
parent 3fe8ce86b1
commit b1990d421a
@@ -122,9 +122,9 @@ class AMPModel(ModelBase):
morph_factor = self.options['morph_factor']
gan_power = self.gan_power = self.options['gan_power']
random_warp = self.options['random_warp']

blur_out_mask = self.options['blur_out_mask']

ct_mode = self.options['ct_mode']
if ct_mode == 'none':
    ct_mode = None
@@ -294,7 +294,10 @@ class AMPModel(ModelBase):
clipnorm = 1.0 if self.options['clipgrad'] else 0.0
lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0

-self.G_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()
+self.G_weights = self.encoder.get_weights() + self.decoder.get_weights()
+if random_warp:
+    self.G_weights += self.inter_src.get_weights() + self.inter_dst.get_weights()

self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu)
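This hunk carries the substance of the commit: the old single G_weights assignment, which always handed the inter blocks to the optimizer, becomes an encoder + decoder list with inter_src/inter_dst appended only while random_warp is enabled, so the inters get no gradient updates once random warping is switched off. A minimal, self-contained sketch of that pattern (a hypothetical Module stand-in, not DeepFaceLab's nn API):

class Module:
    # Stand-in for a leras-style layer that can report its trainable weights.
    def __init__(self, name, n_params):
        self.weights = [f'{name}_w{i}' for i in range(n_params)]
    def get_weights(self):
        return self.weights

encoder, inter_src, inter_dst, decoder = (Module(n, 2) for n in ('enc', 'inter_src', 'inter_dst', 'dec'))

def collect_G_weights(random_warp):
    # The inters still run in the forward pass either way; leaving their weights
    # out of this list simply means the optimizer never updates them.
    G_weights = encoder.get_weights() + decoder.get_weights()
    if random_warp:
        G_weights += inter_src.get_weights() + inter_dst.get_weights()
    return G_weights

assert len(collect_G_weights(random_warp=True)) == 8
assert len(collect_G_weights(random_warp=False)) == 4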
@@ -352,64 +355,64 @@ class AMPModel(ModelBase):
gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code)
gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code)

inter_dims_bin = int(inter_dims*morph_factor)
with tf.device(f'/CPU:0'):
    inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )),
                                                                tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0)
inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None])

gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
gpu_dst_code = gpu_dst_inter_dst_code

inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]),
                               tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 )

gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)

gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)

gpu_target_srcm_anti = 1-gpu_target_srcm
gpu_target_dstm_anti = 1-gpu_target_dstm

gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32)
gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32)

gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur

if blur_out_mask:
    #gpu_target_src = gpu_target_src*gpu_target_srcm_blur + nn.gaussian_blur(gpu_target_src, resolution // 32)*gpu_target_srcm_anti_blur
    #gpu_target_dst = gpu_target_dst*gpu_target_dstm_blur + nn.gaussian_blur(gpu_target_dst, resolution // 32)*gpu_target_dstm_anti_blur
    bg_blur_div = 128

    gpu_target_src = gpu_target_src*gpu_target_srcm + \
                     tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, resolution / bg_blur_div),
                                           (1-nn.gaussian_blur(gpu_target_srcm, resolution / bg_blur_div) ) ) * gpu_target_srcm_anti

    gpu_target_dst = gpu_target_dst*gpu_target_dstm + \
                     tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, resolution / bg_blur_div),
                                           (1-nn.gaussian_blur(gpu_target_dstm, resolution / bg_blur_div)) ) * gpu_target_dstm_anti

gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur

gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur
gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur

# Structural loss
gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
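Two channel-mixing schemes appear in the hunk above: during training, a per-sample shuffled 0/1 vector (inter_rnd_binomial) routes roughly morph_factor of the inter channels to the src-inter branch and the rest to the dst-inter branch, while gpu_src_dst_code is built by taking the first inter_dims*morph_value channels from inter_src and the remainder from inter_dst. A rough NumPy sketch of both mixes, using made-up sizes rather than the model's real dimensions:

import numpy as np

rng = np.random.default_rng(0)
bs, inter_dims, inter_res = 4, 8, 2
morph_factor = 0.5

src_inter_src = rng.normal(size=(bs, inter_dims, inter_res, inter_res))
src_inter_dst = rng.normal(size=(bs, inter_dims, inter_res, inter_res))
dst_inter_src = rng.normal(size=(bs, inter_dims, inter_res, inter_res))
dst_inter_dst = rng.normal(size=(bs, inter_dims, inter_res, inter_res))

# Training-time mix: per sample, a shuffled 0/1 vector picks each channel from one branch or the other.
inter_dims_bin = int(inter_dims * morph_factor)
mask = np.stack([rng.permutation([1.0]*inter_dims_bin + [0.0]*(inter_dims - inter_dims_bin))
                 for _ in range(bs)])[..., None, None]          # shape (bs, inter_dims, 1, 1)
src_code = src_inter_src * mask + src_inter_dst * (1 - mask)

# Morph-time mix: the first inter_dims*morph_value channels come from inter_src,
# the rest from inter_dst (what the tf.slice + tf.concat pair does above).
morph_value = 0.25
split = int(inter_dims * morph_value)
src_dst_code = np.concatenate([dst_inter_src[:, :split], dst_inter_dst[:, split:]], axis=1)
assert src_code.shape == src_dst_code.shape == (bs, inter_dims, inter_res, inter_res)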
@@ -434,8 +437,8 @@ class AMPModel(ModelBase):
# dst-dst background weak loss
gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked)

if gan_power != 0:
    gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked)
    gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked)
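The tiny total_variation_mse term above penalises high-frequency speckle in the background region. The real helper lives in DeepFaceLab's nn package; the sketch below is only the standard squared-difference total-variation formulation for intuition, not the repo's exact code:

import numpy as np

def tv_mse(x):
    # x: (batch, channels, height, width); squared differences of neighbouring pixels,
    # averaged per image -- large values mean a noisy, speckled reconstruction.
    dh = x[:, :, 1:, :] - x[:, :, :-1, :]
    dw = x[:, :, :, 1:] - x[:, :, :, :-1]
    return np.mean(np.square(dh), axis=(1, 2, 3)) + np.mean(np.square(dw), axis=(1, 2, 3))

x = np.random.default_rng(1).normal(size=(2, 3, 8, 8))
print(tv_mse(x))  # one value per batch item; weighted by 1e-6 when added to gpu_G_loss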
@@ -453,7 +456,7 @@ class AMPModel(ModelBase):
gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \
               DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
              ) * gan_power

# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
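DLossOnes is defined elsewhere in Model.py; as used here it plays the role of the usual non-saturating generator term, i.e. sigmoid cross-entropy of the patch-discriminator logits against all-ones labels. A hedged NumPy sketch of that formulation (an approximation, not the repo's exact helper):

import numpy as np

def sigmoid_cross_entropy_with_logits(labels, logits):
    # numerically stable form, the same formula tf.nn.sigmoid_cross_entropy_with_logits uses
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

def d_loss_ones(logits):
    # mean over channel/spatial axes -> one generator-loss value per batch item
    return np.mean(sigmoid_cross_entropy_with_logits(np.ones_like(logits), logits), axis=(1, 2, 3))

logits = np.random.default_rng(2).normal(size=(2, 1, 4, 4))
print(d_loss_ones(logits))  # scaled by gan_power before being added to gpu_G_loss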
@@ -473,7 +476,7 @@ class AMPModel(ModelBase):
src_loss = tf.concat(gpu_src_losses, 0)
dst_loss = tf.concat(gpu_dst_losses, 0)
train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients))

if gan_power != 0:
    GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) )
@@ -554,8 +557,8 @@ class AMPModel(ModelBase):
dst_generators_count = cpu_count // 2
if ct_mode is not None:
    src_generators_count = int(src_generators_count * 1.5)

self.set_training_data_generators ([
        SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),