AMP: removed eyes mouth prio option(default enabled), removed masked training option(default enabled).

This commit is contained in:
iperov 2021-07-17 21:58:57 +04:00
parent 959a3530f8
commit 0748b8d043

View File

@ -18,44 +18,26 @@ class AMPModel(ModelBase):
def on_initialize_options(self):
device_config = nn.getCurrentDeviceConfig()
lowest_vram = 2
if len(device_config.devices) != 0:
lowest_vram = device_config.devices.get_worst_device().total_mem_gb
if lowest_vram >= 4:
suggest_batch_size = 8
else:
suggest_batch_size = 4
yn_str = {True:'y',False:'n'}
min_res = 64
max_res = 640
#default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224)
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
inter_dims = self.load_or_def_option('inter_dims', None)
if inter_dims is None:
inter_dims = self.options['ae_dims']
default_inter_dims = self.options['inter_dims'] = inter_dims
default_inter_dims = self.options['inter_dims'] = inter_dims
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None)
default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5)
default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True)
default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', True)
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
#default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)
ask_override = self.ask_override()
if self.is_first_run() or ask_override:
@ -64,12 +46,11 @@ class AMPModel(ModelBase):
self.ask_target_iter()
self.ask_random_src_flip()
self.ask_random_dst_flip()
self.ask_batch_size(suggest_batch_size)
#self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters. Lr-dropout should be enabled.')
self.ask_batch_size(8)
if self.is_first_run():
resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .")
resolution = np.clip ( (resolution // 32) * 32, min_res, max_res)
resolution = np.clip ( (resolution // 32) * 32, 64, 640)
self.options['resolution'] = resolution
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f','wf','head'], help_message="whole face / head").lower()
@ -93,15 +74,10 @@ class AMPModel(ModelBase):
d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2
morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="The smaller the value, the more src-like facial expressions will appear. The larger the value, the less space there is to train a large dst faceset in the neural network. Typical fine value is 0.33"), 0.1, 0.5 )
morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="Typical fine value is 0.5"), 0.1, 0.5 )
self.options['morph_factor'] = morph_factor
if self.is_first_run() or ask_override:
if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head':
self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' or 'head' type. Masked training clips training area to full_face mask or XSeg mask, thus network will train the faces properly.")
self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.')
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
@ -125,10 +101,7 @@ class AMPModel(ModelBase):
self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
#self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, uniform_yaw=Y")
self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])
#self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
#override
def on_initialize(self):
@ -138,30 +111,38 @@ class AMPModel(ModelBase):
nn.initialize(data_format=self.model_data_format)
tf = nn.tf
self.resolution = resolution = self.options['resolution']
input_ch=3
ae_dims = self.ae_dims = self.options['ae_dims']
inter_dims = self.inter_dims = self.options['inter_dims']
e_dims = self.options['e_dims']
d_dims = self.options['d_dims']
resolution = self.resolution = self.options['resolution']
e_dims = self.options['e_dims']
ae_dims = self.options['ae_dims']
inter_dims = self.inter_dims = self.options['inter_dims']
inter_res = self.inter_res = resolution // 32
d_dims = self.options['d_dims']
d_mask_dims = self.options['d_mask_dims']
inter_res = self.inter_res = resolution // 32
use_fp16 = False# self.options['use_fp16']
ae_use_fp16 = use_fp16
ae_conv_dtype = tf.float16 if use_fp16 else tf.float32
face_type = self.face_type = {'f' : FaceType.FULL,
'wf' : FaceType.WHOLE_FACE,
'head' : FaceType.HEAD}[ self.options['face_type'] ]
morph_factor = self.options['morph_factor']
gan_power = self.gan_power = self.options['gan_power']
random_warp = self.options['random_warp']
ct_mode = self.options['ct_mode']
if ct_mode == 'none':
ct_mode = None
use_fp16 = self.is_exporting
conv_dtype = tf.float16 if use_fp16 else tf.float32
class Downscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=5 ):
self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=ae_conv_dtype)
self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype)
def forward(self, x):
return tf.nn.leaky_relu(self.conv1(x), 0.1)
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
def forward(self, x):
x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2)
@ -169,8 +150,8 @@ class AMPModel(ModelBase):
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
def forward(self, inp):
x = self.conv1(inp)
@ -191,7 +172,7 @@ class AMPModel(ModelBase):
self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims )
def forward(self, x):
if ae_use_fp16:
if use_fp16:
x = tf.cast(x, tf.float16)
x = self.down1(x)
x = self.res1(x)
@ -200,7 +181,7 @@ class AMPModel(ModelBase):
x = self.down4(x)
x = self.down5(x)
x = self.res5(x)
if ae_use_fp16:
if use_fp16:
x = tf.cast(x, tf.float32)
x = nn.pixel_norm(nn.flatten(x), axes=-1)
x = self.dense1(x)
@ -235,15 +216,15 @@ class AMPModel(ModelBase):
self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3)
self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3)
self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=ae_conv_dtype)
self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=ae_conv_dtype)
self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
def forward(self, z):
if ae_use_fp16:
if use_fp16:
z = tf.cast(z, tf.float16)
x = self.upscale0(z)
@ -259,60 +240,22 @@ class AMPModel(ModelBase):
self.out_conv1(x),
self.out_conv2(x),
self.out_conv3(x)), nn.conv2d_ch_axis), 2) )
m = self.upscalem0(z)
m = self.upscalem1(m)
m = self.upscalem2(m)
m = self.upscalem3(m)
m = self.upscalem4(m)
m = tf.nn.sigmoid(self.out_convm(m))
if ae_use_fp16:
if use_fp16:
x = tf.cast(x, tf.float32)
m = tf.cast(m, tf.float32)
return x, m
self.face_type = {'f' : FaceType.FULL,
'wf' : FaceType.WHOLE_FACE,
'head' : FaceType.HEAD}[ self.options['face_type'] ]
if 'eyes_prio' in self.options:
self.options.pop('eyes_prio')
eyes_mouth_prio = self.options['eyes_mouth_prio']
morph_factor = self.options['morph_factor']
gan_power = self.gan_power = self.options['gan_power']
random_warp = self.options['random_warp']
random_src_flip = self.random_src_flip
random_dst_flip = self.random_dst_flip
#pretrain = self.pretrain = self.options['pretrain']
#if self.pretrain_just_disabled:
# self.set_iter(0)
# self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
# random_warp = False if self.pretrain else self.options['random_warp']
# random_src_flip = self.random_src_flip if not self.pretrain else True
# random_dst_flip = self.random_dst_flip if not self.pretrain else True
# if self.pretrain:
# self.options_show_override['gan_power'] = 0.0
# self.options_show_override['random_warp'] = False
# self.options_show_override['lr_dropout'] = 'n'
# self.options_show_override['uniform_yaw'] = True
masked_training = self.options['masked_training']
ct_mode = self.options['ct_mode']
if ct_mode == 'none':
ct_mode = None
models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
mask_shape = nn.get4Dshape(resolution,resolution,1)
self.model_filename_list = []
@ -333,7 +276,6 @@ class AMPModel(ModelBase):
self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t')
# Initializing model classes
with tf.device (models_opt_device):
self.encoder = Encoder(name='encoder')
self.inter_src = Inter(name='inter_src')
@ -346,30 +288,21 @@ class AMPModel(ModelBase):
[self.decoder , 'decoder.npy'] ]
if self.is_training:
if gan_power != 0:
self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
self.model_filename_list += [ [self.GAN, 'GAN.npy'] ]
# Initialize optimizers
lr=5e-5
lr_dropout = 0.3
clipnorm = 1.0 if self.options['clipgrad'] else 0.0
self.all_weights = self.encoder.get_weights() + self.decoder.get_weights()
#if pretrain:
# self.trainable_weights = self.encoder.get_weights() + self.decoder.get_weights()
#else:
self.trainable_weights = self.encoder.get_weights() + self.decoder.get_weights()
self.src_dst_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=0.3, clipnorm=clipnorm, name='src_dst_opt')
self.src_dst_opt.initialize_variables (self.all_weights, vars_on_cpu=optimizer_vars_on_cpu)
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
if gan_power != 0:
self.GAN_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt')
self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)#+self.D_src_x2.get_weights()
self.model_filename_list += [ (self.GAN_opt, 'GAN_opt.npy') ]
self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=0.3, clipnorm=clipnorm, name='GAN_opt')
self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
self.model_filename_list += [ [self.GAN, 'GAN.npy'],
[self.GAN_opt, 'GAN_opt.npy'] ]
if self.is_training:
# Adjust batch size for multiple GPU
@ -387,8 +320,8 @@ class AMPModel(ModelBase):
gpu_src_losses = []
gpu_dst_losses = []
gpu_G_loss_gvs = []
gpu_D_src_dst_loss_gvs = []
gpu_G_loss_gradients = []
gpu_GAN_loss_grads = []
for gpu_id in range(gpu_count):
with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
@ -408,16 +341,8 @@ class AMPModel(ModelBase):
gpu_src_code = self.encoder (gpu_warped_src)
gpu_dst_code = self.encoder (gpu_warped_dst)
# if pretrain:
# gpu_src_inter_src_code = self.inter_src (gpu_src_code)
# gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)
# gpu_src_code = gpu_src_inter_src_code * nn.random_binomial( [bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1,1] , p=morph_factor)
# gpu_dst_code = gpu_src_dst_code = gpu_dst_inter_dst_code * nn.random_binomial( [bs_per_gpu, gpu_dst_inter_dst_code.shape.as_list()[1], 1,1] , p=morph_factor)
# else:
gpu_src_inter_src_code = self.inter_src (gpu_src_code)
gpu_src_inter_dst_code = self.inter_dst (gpu_src_code)
gpu_dst_inter_src_code = self.inter_src (gpu_dst_code)
gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)
gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code)
gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code)
inter_rnd_binomial = nn.random_binomial( [bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1,1] , p=morph_factor)
gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
@ -431,62 +356,50 @@ class AMPModel(ModelBase):
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
gpu_pred_src_src_list.append(gpu_pred_src_src)
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
gpu_pred_src_dst_list.append(gpu_pred_src_dst)
gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
gpu_target_srcm_blur = tf.clip_by_value( nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ), 0, 0.5) * 2
gpu_target_dstm_blur = tf.clip_by_value(nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ), 0, 0.5) * 2
gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur
gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2
gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur
gpu_target_dst_anti_masked = gpu_target_dst*(1.0-gpu_target_dstm_blur)
gpu_target_src_anti_masked = gpu_target_src*(1.0-gpu_target_srcm_blur)
gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
gpu_target_dst_masked_opt = gpu_target_dst*gpu_target_dstm_blur if masked_training else gpu_target_dst
gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur
gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur
gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur)
gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*(1.0-gpu_target_dstm_blur)
# Structural loss
gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
gpu_dst_loss = tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
if resolution < 256:
gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
else:
gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
if eyes_mouth_prio:
gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em ), axis=[1,2,3])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
gpu_dst_loss += 0.1*tf.reduce_mean(tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
gpu_dst_losses += [gpu_dst_loss]
# Pixel loss
gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3])
gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3])
#if not pretrain:
if resolution < 256:
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
else:
gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
if eyes_mouth_prio:
gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em ), axis=[1,2,3])
# Eyes+mouth prio loss
gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3])
gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3])
# Mask loss
gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
#else:
# gpu_src_loss = gpu_dst_loss
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
# dst-dst background weak loss
gpu_dst_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
gpu_dst_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked)
gpu_src_losses += [gpu_src_loss]
#if pretrain:
# gpu_G_loss = gpu_dst_loss
#else:
gpu_dst_losses += [gpu_dst_loss]
gpu_G_loss = gpu_src_loss + gpu_dst_loss
def DLossOnes(logits):
@ -496,30 +409,28 @@ class AMPModel(ModelBase):
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3])
if gan_power != 0:
gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked_opt)
gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked_opt)
gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked_opt)
gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked_opt)
gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked)
gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked)
gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked)
gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked)
gpu_D_src_dst_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \
DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \
DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \
DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
) * ( 1.0 / 8)
gpu_GAN_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \
DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \
DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \
DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
) * (1.0 / 8)
gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.GAN.get_weights() ) ]
gpu_GAN_loss_grads += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ]
gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \
DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
) * gan_power
if masked_training:
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.trainable_weights ) ]
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.encoder.get_weights() + self.decoder.get_weights() ) ]
# Average losses and gradients, and create optimizer update ops
with tf.device(f'/CPU:0'):
@ -533,15 +444,15 @@ class AMPModel(ModelBase):
with tf.device (models_opt_device):
src_loss = tf.concat(gpu_src_losses, 0)
dst_loss = tf.concat(gpu_dst_losses, 0)
src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))
train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients))
if gan_power != 0:
src_D_src_dst_loss_gv_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) )
GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_grads) )
# Initializing training and view functions
def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
def train(warped_src, target_src, target_srcm, target_srcm_em, \
warped_dst, target_dst, target_dstm, target_dstm_em, ):
s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
s, d, _ = nn.tf_sess.run ([src_loss, dst_loss, train_op],
feed_dict={self.warped_src :warped_src,
self.target_src :target_src,
self.target_srcm:target_srcm,
@ -552,20 +463,20 @@ class AMPModel(ModelBase):
self.target_dstm_em:target_dstm_em,
})
return s, d
self.src_dst_train = src_dst_train
self.train = train
if gan_power != 0:
def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
warped_dst, target_dst, target_dstm, target_dstm_em, ):
nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src,
self.target_src :target_src,
self.target_srcm:target_srcm,
self.target_srcm_em:target_srcm_em,
self.warped_dst :warped_dst,
self.target_dst :target_dst,
self.target_dstm:target_dstm,
self.target_dstm_em:target_dstm_em})
self.D_src_dst_train = D_src_dst_train
def GAN_train(warped_src, target_src, target_srcm, target_srcm_em, \
warped_dst, target_dst, target_dstm, target_dstm_em, ):
nn.tf_sess.run ([GAN_train_op], feed_dict={self.warped_src :warped_src,
self.target_src :target_src,
self.target_srcm:target_srcm,
self.target_srcm_em:target_srcm_em,
self.warped_dst :warped_dst,
self.target_dst :target_dst,
self.target_dstm:target_dstm,
self.target_dstm_em:target_dstm_em})
self.GAN_train = GAN_train
def AE_view(warped_src, warped_dst, morph_value):
return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
@ -576,8 +487,8 @@ class AMPModel(ModelBase):
#Initializing merge function
with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'):
gpu_dst_code = self.encoder (self.warped_dst)
gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code)
gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code)
gpu_dst_inter_src_code = self.inter_src (gpu_dst_code)
gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)
inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]),
@ -593,16 +504,10 @@ class AMPModel(ModelBase):
# Loading/initializing all models/optimizers weights
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
# if self.pretrain_just_disabled:
# do_init = False
# if model == self.inter_src or model == self.inter_dst:
# do_init = True
# else:
do_init = self.is_first_run()
if self.is_training and gan_power != 0 and model == self.GAN:
if self.gan_model_changed:
do_init = True
if not do_init:
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
if do_init:
@ -614,7 +519,7 @@ class AMPModel(ModelBase):
training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path()
training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path()
random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain
random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain
cpu_count = min(multiprocessing.cpu_count(), 8)
src_generators_count = cpu_count // 2
@ -624,21 +529,21 @@ class AMPModel(ModelBase):
self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=random_src_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
sample_process_options=SampleProcessor.Options(random_flip=self.random_src_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
generators_count=src_generators_count ),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
sample_process_options=SampleProcessor.Options(random_flip=self.random_dst_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
generators_count=dst_generators_count )
@ -646,15 +551,12 @@ class AMPModel(ModelBase):
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
#if self.pretrain_just_disabled:
# self.update_sample_for_preview(force_new=True)
def export_dfm (self):
output_path=self.get_strpath_storage_for_file('model.dfm')
io.log_info(f'Dumping .dfm to {output_path}')
tf = nn.tf
with tf.device (nn.tf_default_device_name):
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
@ -679,13 +581,13 @@ class AMPModel(ModelBase):
tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
tf.identity(gpu_pred_src_dst, name='out_celeb_face')
tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
output_graph_def = tf.graph_util.convert_variables_to_constants(
nn.tf_sess,
tf.get_default_graph().as_graph_def(),
nn.tf_sess,
tf.get_default_graph().as_graph_def(),
['out_face_mask','out_celeb_face','out_celeb_face_mask']
)
)
import tf2onnx
with tf.device("/CPU:0"):
model_proto, _ = tf2onnx.convert._convert_common(
@ -717,7 +619,7 @@ class AMPModel(ModelBase):
( (warped_src, target_src, target_srcm, target_srcm_em), \
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
for i in range(bs):
self.last_src_samples_loss.append ( (src_loss[i], warped_src[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
@ -737,12 +639,12 @@ class AMPModel(ModelBase):
target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] )
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
if self.gan_power != 0:
self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )