From b1b5d6f482f93b1f987bd2321b7cc8f48d2787c5 Mon Sep 17 00:00:00 2001 From: iperov Date: Sat, 9 Oct 2021 13:58:46 +0400 Subject: [PATCH] AMP, SAEHD: In the sample generator, the random scaling was increased from -0.05+0.05 to -0.125+0.125 , which improves the generalization of faces. --- models/Model_AMP/Model.py | 4 ++-- models/Model_SAEHD/Model.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index ae8eb61..04a1c26 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -561,7 +561,7 @@ class AMPModel(ModelBase): self.set_training_data_generators ([ SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=self.random_src_flip), + sample_process_options=SampleProcessor.Options(scale_range=[-0.125, 0.125], random_flip=self.random_src_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, @@ -571,7 +571,7 @@ class AMPModel(ModelBase): generators_count=src_generators_count ), SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=self.random_dst_flip), + sample_process_options=SampleProcessor.Options(scale_range=[-0.125, 0.125], random_flip=self.random_dst_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index e0829ec..394b010 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -668,7 +668,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... self.set_training_data_generators ([ SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=random_src_flip), + sample_process_options=SampleProcessor.Options(scale_range=[-0.125, 0.125], random_flip=random_src_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, @@ -678,7 +678,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... generators_count=src_generators_count ), SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), - sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip), + sample_process_options=SampleProcessor.Options(scale_range=[-0.125, 0.125], random_flip=random_dst_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},