AMP: code refactoring, fix preview history

added dumpdflive command
iperov 2021-06-26 10:44:41 +04:00
parent 6d89d7fa4c
commit 5783191849
9 changed files with 143 additions and 144 deletions

main.py

@@ -127,7 +127,6 @@ if __name__ == "__main__":
'silent_start' : arguments.silent_start,
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
'debug' : arguments.debug,
- 'dump_ckpt' : arguments.dump_ckpt,
}
from mainscripts import Trainer
Trainer.main(**kwargs)
@@ -145,11 +144,19 @@ if __name__ == "__main__":
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
- p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.")
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
p.set_defaults (func=process_train)
+ def process_dumpdflive(arguments):
+ osex.set_process_lowest_prio()
+ from mainscripts import DumpDFLive
+ DumpDFLive.main(model_class_name = arguments.model_name, saved_models_path = Path(arguments.model_dir))
+ p = subparsers.add_parser( "dumpdflive", help="Dump model to use in DFLive.")
+ p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Saved models dir.")
+ p.add_argument('--model', required=True, dest="model_name", choices=pathex.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Model class name.")
+ p.set_defaults (func=process_dumpdflive)
def process_merge(arguments):
osex.set_process_lowest_prio()
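The new subcommand is invoked as: python main.py dumpdflive --model-dir <saved models dir> --model SAEHD. The mainscripts/DumpDFLive.py module itself is not part of this diff; below is a minimal sketch of what it plausibly does, inferred from the call above and from the new model.dump_dflive() method, with the constructor kwargs being assumptions:

# mainscripts/DumpDFLive.py -- hypothetical sketch, not the actual file
import models

def main(model_class_name=None, saved_models_path=None):
    # Load the chosen model class from its saved weights and export it;
    # is_training and any further constructor kwargs are assumed here.
    model = models.import_model(model_class_name)(
        is_training=False,
        saved_models_path=saved_models_path)
    model.dump_dflive()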

mainscripts/Trainer.py

@@ -27,7 +27,6 @@ def trainerThread (s2c, c2s, e,
silent_start=False,
execute_programs = None,
debug=False,
- dump_ckpt=False,
**kwargs):
while True:
try:
@@ -43,12 +42,9 @@ def trainerThread (s2c, c2s, e,
if not saved_models_path.exists():
saved_models_path.mkdir(exist_ok=True, parents=True)
- if dump_ckpt:
- cpu_only=True
model = models.import_model(model_class_name)(
- is_training=not dump_ckpt,
+ is_training=True,
saved_models_path=saved_models_path,
training_data_src_path=training_data_src_path,
training_data_dst_path=training_data_dst_path,
@@ -61,11 +57,6 @@ def trainerThread (s2c, c2s, e,
silent_start=silent_start,
debug=debug)
- if dump_ckpt:
- e.set()
- model.dump_ckpt()
- break
is_reached_goal = model.is_reached_iter_goal()
shared_state = { 'after_save' : False }

models/ModelBase.py

@@ -232,7 +232,7 @@ class ModelBase(object):
preview_id_counter = 0
while not choosed:
self.sample_for_preview = self.generate_next_samples()
- previews = self.get_static_previews()
+ previews = self.get_history_previews()
io.show_image( wnd_name, ( previews[preview_id_counter % len(previews) ][1] *255).astype(np.uint8) )
@@ -258,7 +258,7 @@ class ModelBase(object):
self.sample_for_preview = self.generate_next_samples()
try:
- self.get_static_previews()
+ self.get_history_previews()
except:
self.sample_for_preview = self.generate_next_samples()
@@ -347,7 +347,7 @@ class ModelBase(object):
return ( ('loss_src', 0), ('loss_dst', 0) )
#overridable
- def onGetPreview(self, sample):
+ def onGetPreview(self, sample, for_history=False):
#you can return multiple previews
#return [ ('preview_name',preview_rgb), ... ]
return []
@@ -377,8 +377,8 @@ class ModelBase(object):
def get_previews(self):
return self.onGetPreview ( self.last_sample )
- def get_static_previews(self):
- return self.onGetPreview (self.sample_for_preview)
+ def get_history_previews(self):
+ return self.onGetPreview (self.sample_for_preview, for_history=True)
def get_preview_history_writer(self):
if self.preview_history_writer is None:
@@ -484,7 +484,7 @@ class ModelBase(object):
plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ]
if self.write_preview_history:
- previews = self.get_static_previews()
+ previews = self.get_history_previews()
for i in range(len(previews)):
name, bgr = previews[i]
path = self.preview_history_path / name
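The rename makes the contract explicit: get_previews() still renders the last training batch, while get_history_previews() renders the fixed sample_for_preview and now tells the model that the image is destined for the saved preview history. A model override would follow this pattern (a sketch with a hypothetical subclass; the real overrides appear in the model hunks below):

import numpy as np
from models import ModelBase  # DFL-internal import, layout assumed

class MyModel(ModelBase):  # hypothetical, for illustration only
    def onGetPreview(self, samples, for_history=False):
        # History previews always show sample 0 so consecutive images in the
        # saved history stay comparable; live previews pick a random sample.
        i = 0 if for_history else np.random.randint(len(samples))
        return [('MyModel preview', samples[i])]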

models/Model_AMP/Model.py

@@ -87,7 +87,7 @@ class AMPModel(ModelBase):
d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2
morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="The smaller the value, the more src-like facial expressions will appear. The larger the value, the less space there is to train a large dst faceset in the neural network. Typical fine value is 0.33"), 0.1, 0.5 )
self.options['morph_factor'] = morph_factor
@@ -121,9 +121,9 @@ class AMPModel(ModelBase):
self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, uniform_yaw=Y")
self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
@@ -137,34 +137,26 @@ class AMPModel(ModelBase):
self.resolution = resolution = self.options['resolution']
input_ch=3
ae_dims = self.ae_dims = self.options['ae_dims']
e_dims = self.options['e_dims']
d_dims = self.options['d_dims']
d_mask_dims = self.options['d_mask_dims']
lowest_dense_res = self.lowest_dense_res = resolution // 32
class Downscale(nn.ModelBase):
- def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ):
- self.in_ch = in_ch
- self.out_ch = out_ch
- self.kernel_size = kernel_size
- super().__init__(*kwargs)
- def on_build(self, *args, **kwargs ):
- self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME')
+ def on_build(self, in_ch, out_ch, kernel_size=5 ):
+ self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME')
def forward(self, x):
- x = self.conv1(x)
- x = tf.nn.leaky_relu(x, 0.1)
- return x
- def get_out_ch(self):
- return self.out_ch
+ return tf.nn.leaky_relu(self.conv1(x), 0.1)
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
- self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
+ self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME')
def forward(self, x):
- x = self.conv1(x)
- x = tf.nn.leaky_relu(x, 0.1)
- x = nn.depth_to_space(x, 2)
+ x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2)
return x
class ResidualBlock(nn.ModelBase):
@@ -180,15 +172,15 @@ class AMPModel(ModelBase):
return x
class Encoder(nn.ModelBase):
- def on_build(self, in_ch, e_ch, ae_ch):
- self.down1 = Downscale(in_ch, e_ch, kernel_size=5)
- self.res1 = ResidualBlock(e_ch)
- self.down2 = Downscale(e_ch, e_ch*2, kernel_size=5)
- self.down3 = Downscale(e_ch*2, e_ch*4, kernel_size=5)
- self.down4 = Downscale(e_ch*4, e_ch*8, kernel_size=5)
- self.down5 = Downscale(e_ch*8, e_ch*8, kernel_size=5)
- self.res5 = ResidualBlock(e_ch*8)
- self.dense1 = nn.Dense( lowest_dense_res*lowest_dense_res*e_ch*8, ae_ch )
+ def on_build(self):
+ self.down1 = Downscale(input_ch, e_dims, kernel_size=5)
+ self.res1 = ResidualBlock(e_dims)
+ self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5)
+ self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5)
+ self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5)
+ self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5)
+ self.res5 = ResidualBlock(e_dims*8)
+ self.dense1 = nn.Dense( lowest_dense_res*lowest_dense_res*e_dims*8, ae_dims )
def forward(self, inp):
x = inp
@@ -199,53 +191,45 @@ class AMPModel(ModelBase):
x = self.down4(x)
x = self.down5(x)
x = self.res5(x)
- x = nn.flatten(x)
- x = nn.pixel_norm(x, axes=-1)
+ x = nn.pixel_norm(nn.flatten(x), axes=-1)
x = self.dense1(x)
return x
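nn.pixel_norm, now applied directly to the flattened encoder output before the dense bottleneck, rescales each feature vector to unit root-mean-square. A NumPy sketch of the operation (the epsilon is an assumption; the leras implementation may differ in detail):

import numpy as np

def pixel_norm(x, axis=-1, eps=1e-6):
    # Divide by the RMS along `axis` (ProGAN-style pixel normalization).
    return x / np.sqrt(np.mean(np.square(x), axis=axis, keepdims=True) + eps)

v = np.random.randn(2, 256).astype(np.float32)
print(np.mean(np.square(pixel_norm(v)), axis=-1))  # ~1.0 per row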
class Inter(nn.ModelBase):
- def __init__(self, ae_ch, ae_out_ch, **kwargs):
- self.ae_ch, self.ae_out_ch = ae_ch, ae_out_ch
- super().__init__(**kwargs)
def on_build(self):
- ae_ch, ae_out_ch = self.ae_ch, self.ae_out_ch
- self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch )
+ self.dense2 = nn.Dense( ae_dims, lowest_dense_res * lowest_dense_res * ae_dims )
def forward(self, inp):
x = inp
x = self.dense2(x)
- x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch)
+ x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, ae_dims)
return x
- def get_out_ch(self):
- return self.ae_out_ch
class Decoder(nn.ModelBase):
- def on_build(self, in_ch, d_ch, d_mask_ch ):
- self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3)
- self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3)
- self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3)
- self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3)
+ def on_build(self ):
+ self.upscale0 = Upscale(ae_dims, d_dims*8, kernel_size=3)
+ self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3)
+ self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3)
+ self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3)
- self.res0 = ResidualBlock(d_ch*8, kernel_size=3)
- self.res1 = ResidualBlock(d_ch*8, kernel_size=3)
- self.res2 = ResidualBlock(d_ch*4, kernel_size=3)
- self.res3 = ResidualBlock(d_ch*2, kernel_size=3)
+ self.res0 = ResidualBlock(d_dims*8, kernel_size=3)
+ self.res1 = ResidualBlock(d_dims*8, kernel_size=3)
+ self.res2 = ResidualBlock(d_dims*4, kernel_size=3)
+ self.res3 = ResidualBlock(d_dims*2, kernel_size=3)
- self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3)
- self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3)
- self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3)
- self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3)
- self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3)
- self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME')
+ self.upscalem0 = Upscale(ae_dims, d_mask_dims*8, kernel_size=3)
+ self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3)
+ self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3)
+ self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3)
+ self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3)
+ self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME')
- self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME')
- self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')
- self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')
- self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME')
+ self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME')
+ self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME')
+ self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME')
+ self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME')
def forward(self, inp):
z = inp
@@ -280,27 +264,24 @@ class AMPModel(ModelBase):
eyes_mouth_prio = self.options['eyes_mouth_prio']
ae_dims = self.ae_dims = self.options['ae_dims']
e_dims = self.options['e_dims']
d_dims = self.options['d_dims']
d_mask_dims = self.options['d_mask_dims']
morph_factor = self.options['morph_factor']
pretrain = self.pretrain = self.options['pretrain']
if self.pretrain_just_disabled:
self.set_iter(0)
self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
random_warp = False if self.pretrain else self.options['random_warp']
random_src_flip = self.random_src_flip if not self.pretrain else True
random_dst_flip = self.random_dst_flip if not self.pretrain else True
if self.pretrain:
self.options_show_override['gan_power'] = 0.0
self.options_show_override['random_warp'] = False
self.options_show_override['lr_dropout'] = 'n'
self.options_show_override['uniform_yaw'] = True
masked_training = self.options['masked_training']
ct_mode = self.options['ct_mode']
if ct_mode == 'none':
@@ -310,7 +291,7 @@ class AMPModel(ModelBase):
models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
input_ch=3
bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
mask_shape = nn.get4Dshape(resolution,resolution,1)
self.model_filename_list = []
@@ -333,10 +314,10 @@ class AMPModel(ModelBase):
# Initializing model classes
with tf.device (models_opt_device):
- self.encoder = Encoder(in_ch=input_ch, e_ch=e_dims, ae_ch=ae_dims, name='encoder')
- self.inter_src = Inter(ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter_src')
- self.inter_dst = Inter(ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter_dst')
- self.decoder = Decoder(in_ch=ae_dims, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')
+ self.encoder = Encoder(name='encoder')
+ self.inter_src = Inter(name='inter_src')
+ self.inter_dst = Inter(name='inter_dst')
+ self.decoder = Decoder(name='decoder')
self.model_filename_list += [ [self.encoder, 'encoder.npy'],
[self.inter_src, 'inter_src.npy'],
@@ -351,7 +332,7 @@ class AMPModel(ModelBase):
# Initialize optimizers
lr=5e-5
lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain else 1.0
clipnorm = 1.0 if self.options['clipgrad'] else 0.0
self.all_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights()
@@ -386,8 +367,6 @@ class AMPModel(ModelBase):
gpu_src_losses = []
gpu_dst_losses = []
gpu_G_loss_gvs = []
- gpu_GAN_loss_gvs = []
- gpu_D_code_loss_gvs = []
gpu_D_src_dst_loss_gvs = []
for gpu_id in range(gpu_count):
@@ -407,7 +386,7 @@ class AMPModel(ModelBase):
# process model tensors
gpu_src_code = self.encoder (gpu_warped_src)
gpu_dst_code = self.encoder (gpu_warped_dst)
if pretrain:
gpu_src_inter_src_code = self.inter_src (gpu_src_code)
gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)
@@ -454,7 +433,7 @@ class AMPModel(ModelBase):
gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur)
gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*(1.0-gpu_target_dstm_blur)
if resolution < 256:
gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
else:
@@ -481,12 +460,12 @@ class AMPModel(ModelBase):
gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
else:
gpu_src_loss = gpu_dst_loss
gpu_src_losses += [gpu_src_loss]
if pretrain:
gpu_G_loss = gpu_dst_loss
else:
gpu_G_loss = gpu_src_loss + gpu_dst_loss
def DLossOnes(logits):
@@ -537,8 +516,6 @@ class AMPModel(ModelBase):
if gan_power != 0:
src_D_src_dst_loss_gv_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) )
- #GAN_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gvs) )
# Initializing training and view functions
def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
@@ -569,7 +546,6 @@ class AMPModel(ModelBase):
self.target_dstm_em:target_dstm_em})
self.D_src_dst_train = D_src_dst_train
def AE_view(warped_src, warped_dst, morph_value):
return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst, self.morph_value_t:[morph_value] })
@@ -605,13 +581,11 @@ class AMPModel(ModelBase):
if self.is_training and gan_power != 0 and model == self.GAN:
if self.gan_model_changed:
do_init = True
if not do_init:
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
if do_init:
model.init_weights()
###############
# initializing sample generators
@@ -621,7 +595,6 @@ class AMPModel(ModelBase):
random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None
cpu_count = min(multiprocessing.cpu_count(), 8)
src_generators_count = cpu_count // 2
dst_generators_count = cpu_count // 2
@@ -654,9 +627,13 @@ class AMPModel(ModelBase):
self.last_dst_samples_loss = []
if self.pretrain_just_disabled:
self.update_sample_for_preview(force_new=True)
- def dump_ckpt(self):
+ def dump_dflive (self):
+ output_path=self.get_strpath_storage_for_file('model.dflive')
+ io.log_info(f'Dumping .dflive to {output_path}')
tf = nn.tf
with tf.device (nn.tf_default_device_name):
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
@@ -667,9 +644,9 @@ class AMPModel(ModelBase):
gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code)
gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code)
- ae_dims_slice = tf.cast(self.ae_dims*morph_value[0], tf.int32)
- gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, ae_dims_slice , self.lowest_dense_res, self.lowest_dense_res]),
- tf.slice(gpu_dst_inter_dst_code, [0,ae_dims_slice,0,0], [-1,self.ae_dims-ae_dims_slice, self.lowest_dense_res,self.lowest_dense_res]) ), 1 )
+ inter_dims_slice = tf.cast(self.inter_dims*morph_value[0], tf.int32)
+ gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , self.inter_res, self.inter_res]),
+ tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,self.inter_dims-inter_dims_slice, self.inter_res,self.inter_res]) ), 1 )
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)
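The morph is a channel-wise blend of the two inter codes: the first inter_dims * morph_value channels come from inter_src, the remainder from inter_dst. The same operation in NumPy (a sketch; NCHW shapes, matching the tf.slice axes above):

import numpy as np

def morph_code(src_code, dst_code, morph_value):
    # src_code, dst_code: (batch, inter_dims, inter_res, inter_res)
    k = int(src_code.shape[1] * morph_value)  # channels taken from the src inter
    return np.concatenate([src_code[:, :k], dst_code[:, k:]], axis=1)

a = np.zeros((1, 8, 4, 4)); b = np.ones((1, 8, 4, 4))
print(morph_code(a, b, 0.75)[0, :, 0, 0])  # first 6 channels 0.0, last 2 channels 1.0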
@@ -688,9 +665,15 @@ class AMPModel(ModelBase):
['out_face_mask','out_celeb_face','out_celeb_face_mask']
)
- pb_filepath = self.get_strpath_storage_for_file('.pb')
- with tf.gfile.GFile(pb_filepath, "wb") as f:
- f.write(output_graph_def.SerializeToString())
+ import tf2onnx
+ with tf.device("/CPU:0"):
+ model_proto, _ = tf2onnx.convert._convert_common(
+ output_graph_def,
+ name='AMP',
+ input_names=['in_face:0','morph_value:0'],
+ output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
+ opset=13,
+ output_path=output_path)
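Since the exported model.dflive is a plain ONNX file (tf2onnx writes ONNX to output_path), it can be smoke-tested outside DeepFaceLab with onnxruntime. A sketch, with the exact tensor names read back from the session rather than assumed, and a placeholder resolution:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('model.dflive')
names = [i.name for i in sess.get_inputs()]          # expect the in_face and morph_value tensors
face = np.zeros((1, 224, 224, 3), dtype=np.float32)  # resolution is model-specific
outs = sess.run(None, {names[0]: face, names[1]: np.float32([0.65])})
print([o.shape for o in outs])                       # mask, celeb face, celeb face mask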
#override
def get_model_filename_list(self):
@@ -716,22 +699,24 @@ class AMPModel(ModelBase):
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
for i in range(bs):
- self.last_src_samples_loss.append ( (target_src[i], target_srcm[i], target_srcm_em[i], src_loss[i] ) )
- self.last_dst_samples_loss.append ( (target_dst[i], target_dstm[i], target_dstm_em[i], dst_loss[i] ) )
+ self.last_src_samples_loss.append ( (src_loss[i], warped_src[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
+ self.last_dst_samples_loss.append ( (dst_loss[i], warped_dst[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
if len(self.last_src_samples_loss) >= bs*16:
- src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
- dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)
+ src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
+ dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
- target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
- target_srcm = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
- target_srcm_em = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
+ warped_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
+ target_src = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
+ target_srcm = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
+ target_srcm_em = np.stack( [ x[4] for x in src_samples_loss[:bs] ] )
- target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
- target_dstm = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
- target_dstm_em = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
+ warped_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
+ target_dst = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
+ target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
+ target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] )
- src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
+ src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
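This block is a simple hard-example mining loop, and the commit fixes two bugs in it: the loss now comes first in each recorded tuple, so itemgetter(0) really sorts by loss rather than by a mask array, and the extra training pass now feeds the warped images as network input instead of reusing the targets. The pattern in isolation (a sketch; train_step stands in for src_dst_train):

import operator
import numpy as np

def retrain_hardest(history, bs, train_step):
    # history: list of (loss, warped, target, ...) tuples, loss first
    hardest = sorted(history, key=operator.itemgetter(0), reverse=True)[:bs]
    warped = np.stack([h[1] for h in hardest])  # network input
    target = np.stack([h[2] for h in hardest])  # reconstruction target
    return train_step(warped, target)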
@@ -741,7 +726,7 @@ class AMPModel(ModelBase):
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
#override
- def onGetPreview(self, samples):
+ def onGetPreview(self, samples, for_history=False):
( (warped_src, target_src, target_srcm, target_srcm_em),
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
@@ -771,7 +756,7 @@ class AMPModel(ModelBase):
result = []
- i = np.random.randint(n_samples)
+ i = np.random.randint(n_samples) if not for_history else 0
st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ]
st += [ np.concatenate ((SS[i], DD[i], SD_075[i] ), axis=1) ]
@@ -782,7 +767,6 @@ class AMPModel(ModelBase):
st += [ np.concatenate ((SD_065[i], SD_075[i], SD_100[i]), axis=1) ]
result += [ ('AMP morph list', np.concatenate (st, axis=0 )), ]
st = [ np.concatenate ((DD[i], SD_025[i]*DDM_025[i]*SDM_025[i], SD_050[i]*DDM_050[i]*SDM_050[i]), axis=1) ]
st += [ np.concatenate ((SD_065[i]*DDM_065[i]*SDM_065[i], SD_075[i]*DDM_075[i]*SDM_075[i], SD_100[i]*DDM_100[i]*SDM_100[i]), axis=1) ]
result += [ ('AMP morph list masked', np.concatenate (st, axis=0 )), ]
@@ -791,7 +775,7 @@ class AMPModel(ModelBase):
def predictor_func (self, face, morph_value):
face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC")
bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face, morph_value) ]
return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0]
@@ -802,9 +786,9 @@ class AMPModel(ModelBase):
def predictor_morph(face):
return self.predictor_func(face, morph_factor)
import merger
return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
Model = AMPModel

models/Model_Quick96/Model.py

@@ -278,7 +278,7 @@ class QModel(ModelBase):
return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
#override
- def onGetPreview(self, samples):
+ def onGetPreview(self, samples, for_history=False):
( (warped_src, target_src, target_srcm),
(warped_dst, target_dst, target_dstm) ) = samples

models/Model_SAEHD/Model.py

@@ -659,11 +659,15 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if self.pretrain_just_disabled:
self.update_sample_for_preview(force_new=True)
- def dump_ckpt(self):
+ def dump_dflive (self):
+ output_path=self.get_strpath_storage_for_file('model.dflive')
+ io.log_info(f'Dumping .dflive to {output_path}')
tf = nn.tf
nn.set_data_format('NCHW')
- with tf.device ('/CPU:0'):
+ with tf.device (nn.tf_default_device_name):
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
warped_dst = tf.transpose(warped_dst, (0,3,1,2))
@@ -687,15 +691,26 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))
- saver = tf.train.Saver()
tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
tf.identity(gpu_pred_src_dst, name='out_celeb_face')
tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
- saver.save(nn.tf_sess, self.get_strpath_storage_for_file('.ckpt') )
+ output_graph_def = tf.graph_util.convert_variables_to_constants(
+ nn.tf_sess,
+ tf.get_default_graph().as_graph_def(),
+ ['out_face_mask','out_celeb_face','out_celeb_face_mask']
+ )
+ import tf2onnx
+ with tf.device("/CPU:0"):
+ model_proto, _ = tf2onnx.convert._convert_common(
+ output_graph_def,
+ name='SAEHD',
+ input_names=['in_face:0'],
+ output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
+ opset=13,
+ output_path=output_path)
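Note that tf2onnx.convert._convert_common is an internal, underscore-prefixed tf2onnx entry point, which is presumably why both requirements files below pin tf2onnx==1.8.4. A quick structural check of the exported file (a sketch, again assuming the .dflive file is plain ONNX):

import onnx

m = onnx.load('model.dflive')           # .dflive is an ONNX model file
onnx.checker.check_model(m)             # raises if the graph is malformed
print([i.name for i in m.graph.input])  # SAEHD exports only the face input, no morph value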
#override
def get_model_filename_list(self):
return self.model_filename_list
@@ -751,7 +766,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
#override
- def onGetPreview(self, samples):
+ def onGetPreview(self, samples, for_history=False):
( (warped_src, target_src, target_srcm, target_srcm_em),
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples

models/Model_XSeg/Model.py

@@ -164,7 +164,7 @@ class XSegModel(ModelBase):
return ( ('loss', np.mean(loss) ), )
#override
- def onGetPreview(self, samples):
+ def onGetPreview(self, samples, for_history=False):
n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
srcdst_samples, src_samples, dst_samples = samples

requirements-colab.txt

@@ -6,4 +6,5 @@ ffmpeg-python==0.1.17
scikit-image==0.14.2
scipy==1.4.1
colorama
- tensorflow-gpu==2.3.1
+ tensorflow-gpu==2.3.1
+ tf2onnx==1.8.4

requirements-cuda.txt

@@ -8,3 +8,4 @@ scipy==1.4.1
colorama
tensorflow-gpu==2.4.0
pyqt5
+ tf2onnx==1.8.4