SAEHD: removed 'dst_denoise' option. Added -t arhi option.

This commit is contained in:
iperov 2021-08-29 11:48:55 +04:00
parent 01f1a084b4
commit 6e094d873d
2 changed files with 136 additions and 73 deletions

View File

@ -7,6 +7,10 @@ class DeepFakeArchi(nn.ArchiBase):
mod None - default
'quick'
opts ''
''
't'
"""
def __init__(self, resolution, use_fp16=False, mod=None, opts=None):
super().__init__()
@ -16,7 +20,7 @@ class DeepFakeArchi(nn.ArchiBase):
conv_dtype = tf.float16 if use_fp16 else tf.float32
if mod is None:
class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ):
@ -79,21 +83,44 @@ class DeepFakeArchi(nn.ArchiBase):
self.in_ch = in_ch
self.e_ch = e_ch
super().__init__(**kwargs)
def on_build(self):
self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)
def on_build(self):
if 't' in opts:
self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5)
self.res1 = ResidualBlock(self.e_ch)
self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5)
self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5)
self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5)
self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5)
self.res5 = ResidualBlock(self.e_ch*8)
else:
self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4 if 't' not in opts else 5, kernel_size=5)
def forward(self, x):
if use_fp16:
x = tf.cast(x, tf.float16)
x = nn.flatten(self.down1(x))
if 't' in opts:
x = self.down1(x)
x = self.res1(x)
x = self.down2(x)
x = self.down3(x)
x = self.down4(x)
x = self.down5(x)
x = self.res5(x)
else:
x = self.down1(x)
x = nn.flatten(x)
if 'u' in opts:
x = nn.pixel_norm(x, axes=-1)
if use_fp16:
x = tf.cast(x, tf.float32)
return x
def get_out_res(self, res):
return res // (2**4)
return res // ( (2**4) if 't' not in opts else (2**5) )
def get_out_ch(self):
return self.e_ch * 8
@ -106,59 +133,83 @@ class DeepFakeArchi(nn.ArchiBase):
def on_build(self):
in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch
if 'u' in opts:
self.dense_norm = nn.DenseNorm()
self.dense1 = nn.Dense( in_ch, ae_ch )
self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
if 't' not in opts:
self.upscale1 = Upscale(ae_out_ch, ae_out_ch)
def forward(self, inp):
x = inp
if 'u' in opts:
x = self.dense_norm(x)
x = self.dense1(x)
x = self.dense2(x)
x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
if use_fp16:
x = tf.cast(x, tf.float16)
x = self.upscale1(x)
if 't' not in opts:
x = self.upscale1(x)
return x
def get_out_res(self):
return lowest_dense_res * 2
return lowest_dense_res * 2 if 't' not in opts else lowest_dense_res
def get_out_ch(self):
return self.ae_out_ch
class Decoder(nn.ModelBase):
def on_build(self, in_ch, d_ch, d_mask_ch):
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
def on_build(self, in_ch, d_ch, d_mask_ch):
if 't' not in opts:
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
self.res1 = ResidualBlock(d_ch*4, kernel_size=3)
self.res2 = ResidualBlock(d_ch*2, kernel_size=3)
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
if 'd' in opts:
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
if 'd' in opts:
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
else:
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
else:
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
self.res3 = ResidualBlock(d_ch*2, kernel_size=3)
self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
if 'd' in opts:
self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
else:
self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
def forward(self, z):
x = self.upscale0(z)
x = self.res0(x)
@ -167,6 +218,10 @@ class DeepFakeArchi(nn.ArchiBase):
x = self.upscale2(x)
x = self.res2(x)
if 't' in opts:
x = self.upscale3(x)
x = self.res3(x)
if 'd' in opts:
x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x),
self.out_conv1(x),
@ -179,16 +234,23 @@ class DeepFakeArchi(nn.ArchiBase):
m = self.upscalem0(z)
m = self.upscalem1(m)
m = self.upscalem2(m)
if 'd' in opts:
if 't' in opts:
m = self.upscalem3(m)
if 'd' in opts:
m = self.upscalem4(m)
else:
if 'd' in opts:
m = self.upscalem3(m)
m = tf.nn.sigmoid(self.out_convm(m))
if use_fp16:
x = tf.cast(x, tf.float32)
x = tf.cast(x, tf.float32)
m = tf.cast(m, tf.float32)
return x, m
self.Encoder = Encoder
self.Inter = Inter
self.Decoder = Decoder

View File

@ -35,9 +35,7 @@ class SAEHDModel(ModelBase):
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
archi = self.load_or_def_option('archi', 'liae-ud')
archi = {'dfuhd':'df-u','liaeuhd':'liae-u'}.get(archi, archi) #backward comp
default_archi = self.options['archi'] = archi
default_archi = self.options['archi'] = self.load_or_def_option('archi', 'liae-ud')
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
@ -47,7 +45,6 @@ class SAEHDModel(ModelBase):
default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', False)
default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False)
default_dst_denoise = self.options['dst_denoise'] = self.load_or_def_option('dst_denoise', False)
default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True)
@ -107,7 +104,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if archi_opts is not None:
if len(archi_opts) == 0:
continue
if len([ 1 for opt in archi_opts if opt not in ['u','d'] ]) != 0:
if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0:
continue
if 'd' in archi_opts:
@ -141,7 +138,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.')
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.')
self.options['dst_denoise'] = io.input_bool ("Denoise DST faceset.", default_dst_denoise, help_message='Used in RTM(ReadyToMerge) training with RTM DST faceset. Removes high frequency noise keeping edges. Result is better face syncronization with any face. Can be enabled at any time.')
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
@ -233,7 +229,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
random_src_flip = self.random_src_flip if not self.pretrain else True
random_dst_flip = self.random_dst_flip if not self.pretrain else True
blur_out_mask = self.options['blur_out_mask']
dst_denoise = self.options['dst_denoise']
learn_dst_bg = False#True
if self.pretrain:
self.options_show_override['gan_power'] = 0.0
@ -327,8 +323,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
elif 'liae' in archi_type:
self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()
self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')
self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
@ -413,6 +407,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
gpu_pred_dst_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_dst_code))
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
gpu_pred_src_src_list.append(gpu_pred_src_src)
@ -425,25 +420,30 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2
gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
gpu_target_dstm_style_blur = gpu_target_dstm_blur #default style mask is 0.5 on boundary
gpu_target_dstm_style_anti_blur = 1.0 - gpu_target_dstm_style_blur
gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2
gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
gpu_target_dst_style_masked = gpu_target_dst*gpu_target_dstm_style_blur
gpu_target_dst_style_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_style_blur)
gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_target_dstm_style_anti_blur
gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur
gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur
gpu_target_src_anti_masked = gpu_target_src*(1.0-gpu_target_srcm_blur)
gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst
gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur)
gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur
gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur)
gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*gpu_target_dstm_style_anti_blur
if resolution < 256:
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
@ -483,6 +483,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_G_loss = gpu_src_loss + gpu_dst_loss
if learn_dst_bg and masked_training and 'liae' in archi_type:
gpu_G_loss += tf.reduce_mean( tf.square(gpu_pred_dst_dst_no_code_grad*gpu_target_dstm_anti_blur-gpu_target_dst_anti_masked),axis=[1,2,3] )
def DLoss(labels,logits):
return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3])
@ -526,14 +529,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \
DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2))
if masked_training:
# Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ]
gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )]
# Average losses and gradients, and create optimizer update ops
@ -560,7 +563,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
# Initializing training and view functions
def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
warped_dst, target_dst, target_dstm, target_dstm_em, ):
s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
s, d = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
feed_dict={self.warped_src :warped_src,
self.target_src :target_src,
self.target_srcm:target_srcm,
@ -569,7 +572,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.target_dst :target_dst,
self.target_dstm:target_dstm,
self.target_dstm_em:target_dstm_em,
})
})[:2]
return s, d
self.src_dst_train = src_dst_train
@ -674,7 +677,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'denoise_filter' : dst_denoise, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
@ -762,13 +764,13 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
bs = self.get_batch_size()
( (warped_src, target_src, target_srcm, target_srcm_em), \
(warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = self.generate_next_samples()
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em)
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
for i in range(bs):
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i],) )
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dst_train[i], target_dstm[i], target_dstm_em[i],) )
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i],) )
if len(self.last_src_samples_loss) >= bs*16:
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
@ -779,11 +781,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
target_dst_train = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] )
target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst_train, target_dstm, target_dstm_em)
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
@ -791,14 +792,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.D_train (warped_src, warped_dst)
if self.gan_power != 0:
self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em)
self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
#override
def onGetPreview(self, samples, for_history=False):
( (warped_src, target_src, target_srcm, target_srcm_em),
(warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = samples
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]