dump_ckpt

iperov 2021-03-23 15:00:24 +04:00
parent 3d0e18b0ad
commit b333fcea4b
3 changed files with 62 additions and 13 deletions

View File

@@ -127,6 +127,7 @@ if __name__ == "__main__":
'silent_start' : arguments.silent_start,
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
'debug' : arguments.debug,
'dump_ckpt' : arguments.dump_ckpt,
}
from mainscripts import Trainer
Trainer.main(**kwargs)
@@ -144,6 +145,7 @@ if __name__ == "__main__":
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.")
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
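
Note: the new flag rides along with the existing options: argparse stores it on arguments.dump_ckpt, the kwargs dict above copies it in, and Trainer.main(**kwargs) forwards it to the trainer thread. A minimal sketch of that forwarding pattern (only the dump_ckpt name comes from the diff; the rest is illustrative, not DeepFaceLab code):

# Sketch of the parse -> kwargs -> **kwargs forwarding used above.
import argparse

p = argparse.ArgumentParser()
p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False)
arguments = p.parse_args(['--dump-ckpt'])

def trainer_thread(dump_ckpt=False, **other):   # stand-in for trainerThread
    return dump_ckpt

kwargs = {'dump_ckpt': arguments.dump_ckpt}
assert trainer_thread(**kwargs) is True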

View File

@@ -27,6 +27,7 @@ def trainerThread (s2c, c2s, e,
silent_start=False,
execute_programs = None,
debug=False,
dump_ckpt=False,
**kwargs):
while True:
try:
@@ -44,7 +45,7 @@ def trainerThread (s2c, c2s, e,
saved_models_path.mkdir(exist_ok=True, parents=True)
model = models.import_model(model_class_name)(
is_training=True,
is_training=not dump_ckpt,
saved_models_path=saved_models_path,
training_data_src_path=training_data_src_path,
training_data_dst_path=training_data_dst_path,
@@ -55,9 +56,13 @@ def trainerThread (s2c, c2s, e,
force_gpu_idxs=force_gpu_idxs,
cpu_only=cpu_only,
silent_start=silent_start,
debug=debug,
)
debug=debug)
if dump_ckpt:
e.set()
model.dump_ckpt()
break
is_reached_goal = model.is_reached_iter_goal()
shared_state = { 'after_save' : False }
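
Note: when dump_ckpt is set, the thread builds the model with is_training disabled, immediately signals the event e (which the launcher presumably waits on before continuing), writes the checkpoint via model.dump_ckpt(), and breaks out of the training loop instead of training. A minimal sketch of that signalling pattern, not DeepFaceLab's actual launcher:

# Sketch: a worker that signals readiness via threading.Event and exits early
# when asked only to dump, mirroring the e.set() / break above.
import threading

def worker(e, dump_ckpt=False):
    if dump_ckpt:
        e.set()      # unblock whoever waits for the model to be ready
        return       # corresponds to `break` in trainerThread

e = threading.Event()
t = threading.Thread(target=worker, args=(e,), kwargs={'dump_ckpt': True})
t.start()
e.wait()
t.join()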

View File

@@ -204,6 +204,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
archi_type, archi_opts = archi_split
elif len(archi_split) == 1:
archi_type, archi_opts = archi_split[0], None
self.archi_type = archi_type
ae_dims = self.options['ae_dims']
e_dims = self.options['e_dims']
@@ -236,22 +238,22 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
input_ch=3
bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
mask_shape = nn.get4Dshape(resolution,resolution,1)
self.model_filename_list = []
with tf.device ('/CPU:0'):
#Place holders on CPU
self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)
self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')
self.target_src = tf.placeholder (nn.floatx, bgr_shape)
self.target_dst = tf.placeholder (nn.floatx, bgr_shape)
self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')
self.target_srcm = tf.placeholder (nn.floatx, mask_shape)
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape)
self.target_dstm = tf.placeholder (nn.floatx, mask_shape)
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape)
self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')
# Initializing model classes
model_archi = nn.DeepFakeArchi(resolution, opts=archi_opts)
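
Note: naming the placeholders gives them stable, predictable names in the serialized graph, presumably so that a consumer restoring the dumped checkpoint can look tensors up by name instead of needing Python references. A minimal TF1-style sketch, assuming the graph above has already been built into the default graph:

# Sketch: fetch the named placeholders defined above by name
# (':0' is TensorFlow's standard suffix for the first output of an op).
import tensorflow as tf   # TF1 graph API; on TF2 use tf.compat.v1

graph = tf.get_default_graph()
warped_dst  = graph.get_tensor_by_name('warped_dst:0')
target_srcm = graph.get_tensor_by_name('target_srcm:0')
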
@@ -609,7 +611,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if do_init:
model.init_weights()
###############
# initializing sample generators
if self.is_training:
training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
@@ -650,7 +655,44 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if self.pretrain_just_disabled:
self.update_sample_for_preview(force_new=True)
def dump_ckpt(self):
tf = nn.tf
with tf.device ('/CPU:0'):
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
warped_dst = tf.transpose(warped_dst, (0,3,1,2))
if 'df' in self.archi_type:
gpu_dst_code = self.inter(self.encoder(warped_dst))
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
elif 'liae' in self.archi_type:
gpu_dst_code = self.encoder (warped_dst)
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))
saver = tf.train.Saver()
tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
tf.identity(gpu_pred_src_dst, name='out_celeb_face')
tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
saver.save(nn.tf_sess, self.get_strpath_storage_for_file('.ckpt') )
#override
def get_model_filename_list(self):
return self.model_filename_list
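
Note: the dumped checkpoint exposes one named input, in_face (NHWC float, shape (None, resolution, resolution, 3)), and three named outputs: out_celeb_face, out_face_mask and out_celeb_face_mask. A minimal consumer sketch, assuming a TF1-style session; the checkpoint path and the 256 resolution are placeholders, not values from the commit:

# Sketch: restore the dumped .ckpt and run the named input/output tensors.
import numpy as np
import tensorflow as tf   # TF1 API; on TF2 use tf.compat.v1 and disable eager execution

ckpt_path = 'path/to/model.ckpt'                   # hypothetical path
saver = tf.train.import_meta_graph(ckpt_path + '.meta')

with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
    face = np.zeros((1, 256, 256, 3), np.float32)  # resolution is model-dependent
    celeb, mask = sess.run(['out_celeb_face:0', 'out_face_mask:0'],
                           feed_dict={'in_face:0': face})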