XSeg: changed pretrain mode

This commit is contained in:
iperov 2021-08-12 17:01:38 +04:00
parent 623eb3856d
commit c5584fbda0

View File

@ -34,7 +34,7 @@ class XSegModel(ModelBase):
self.ask_batch_size(4, range=[2,16])
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain)
if self.options['pretrain'] and self.get_pretraining_data_path() is None:
if not self.is_exporting and (self.options['pretrain'] and self.get_pretraining_data_path() is None):
raise Exception("pretraining_data_path is not defined")
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
@ -42,7 +42,7 @@ class XSegModel(ModelBase):
#override
def on_initialize(self):
device_config = nn.getCurrentDeviceConfig()
self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
self.model_data_format = "NCHW" if self.is_exporting or (len(device_config.devices) != 0 and not self.is_debug()) else "NHWC"
nn.initialize(data_format=self.model_data_format)
tf = nn.tf
@ -85,8 +85,6 @@ class XSegModel(ModelBase):
bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
self.set_batch_size( gpu_count*bs_per_gpu)
targetm_t = tf.placeholder (nn.floatx, mask_shape)
# Compute losses per GPU
gpu_pred_list = []
@ -100,7 +98,6 @@ class XSegModel(ModelBase):
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
gpu_input_t = self.model.input_t [batch_slice,:,:,:]
gpu_target_t = self.model.target_t [batch_slice,:,:,:]
gpu_targetm_t = targetm_t [batch_slice,:,:,:]
# process model tensors
gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t, pretrain=self.pretrain)
@ -108,17 +105,11 @@ class XSegModel(ModelBase):
if self.pretrain:
gpu_targetm_blur = nn.gaussian_blur(gpu_targetm_t, max(1, resolution // 32) )
gpu_targetm_blur = tf.clip_by_value(gpu_targetm_blur, 0, 0.5) * 2
gpu_target_t_blur = gpu_target_t*gpu_targetm_blur
gpu_pred_t_blur = gpu_pred_t*gpu_targetm_t
# Structural loss
gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t_blur, gpu_pred_t_blur, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t_blur, gpu_pred_t_blur, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
# Pixel loss
gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t_blur-gpu_pred_t_blur), axis=[1,2,3])
gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t-gpu_pred_t), axis=[1,2,3])
else:
gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
@ -137,8 +128,8 @@ class XSegModel(ModelBase):
# Initializing training and view functions
if self.pretrain:
def train(input_np, target_np, targetm_np):
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np, targetm_t :targetm_np })
def train(input_np, target_np):
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np})
return l
else:
def train(input_np, target_np):
@ -160,8 +151,7 @@ class XSegModel(ModelBase):
pretrain_gen = SampleGeneratorFace(self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(),
sample_process_options=SampleProcessor.Options(random_flip=True),
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
],
uniform_yaw_distribution=False,
generators_count=cpu_count )
@ -200,13 +190,9 @@ class XSegModel(ModelBase):
#override
def onTrainOneIter(self):
if self.pretrain:
image_np, target_np, targetm_np = self.generate_next_samples()[0]
loss = self.train (image_np, target_np, targetm_np)
else:
image_np, mask_np = self.generate_next_samples()[0]
loss = self.train (image_np, mask_np)
image_np, target_np = self.generate_next_samples()[0]
loss = self.train (image_np, target_np)
return ( ('loss', np.mean(loss) ), )
#override
@ -215,7 +201,7 @@ class XSegModel(ModelBase):
if self.pretrain:
srcdst_samples, = samples
image_np, mask_np, _ = srcdst_samples
image_np, mask_np = srcdst_samples
else:
srcdst_samples, src_samples, dst_samples = samples
image_np, mask_np = srcdst_samples
@ -264,5 +250,34 @@ class XSegModel(ModelBase):
result += [ ('XSeg dst faces', np.concatenate (st, axis=0 )), ]
return result
def export_dfm (self):
output_path = self.get_strpath_storage_for_file(f'model.onnx')
io.log_info(f'Dumping .onnx to {output_path}')
tf = nn.tf
with tf.device (nn.tf_default_device_name):
input_t = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
input_t = tf.transpose(input_t, (0,3,1,2))
_, pred_t = self.model.flow(input_t)
pred_t = tf.transpose(pred_t, (0,2,3,1))
tf.identity(pred_t, name='out_mask')
output_graph_def = tf.graph_util.convert_variables_to_constants(
nn.tf_sess,
tf.get_default_graph().as_graph_def(),
['out_mask']
)
import tf2onnx
with tf.device("/CPU:0"):
model_proto, _ = tf2onnx.convert._convert_common(
output_graph_def,
name='XSeg',
input_names=['in_face:0'],
output_names=['out_mask:0'],
opset=13,
output_path=output_path)
Model = XSegModel