mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2024-03-22 13:10:55 +08:00
AMP: last high loss samples behaviour - same as SAEHD
This commit is contained in:
parent
bfa88c5fd9
commit
4be135af60
|
@ -622,24 +622,22 @@ class AMPModel(ModelBase):
|
|||
src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||
|
||||
for i in range(bs):
|
||||
self.last_src_samples_loss.append ( (src_loss[i], warped_src[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
|
||||
self.last_dst_samples_loss.append ( (dst_loss[i], warped_dst[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
|
||||
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
|
||||
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
|
||||
|
||||
if len(self.last_src_samples_loss) >= bs*16:
|
||||
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
|
||||
dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
|
||||
|
||||
warped_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
|
||||
target_src = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
|
||||
target_srcm = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
|
||||
target_srcm_em = np.stack( [ x[4] for x in src_samples_loss[:bs] ] )
|
||||
target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
|
||||
target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
|
||||
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
|
||||
|
||||
warped_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
|
||||
target_dst = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
|
||||
target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
|
||||
target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] )
|
||||
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
|
||||
target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
|
||||
target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
|
||||
|
||||
src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
|
||||
src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
|
||||
self.last_src_samples_loss = []
|
||||
self.last_dst_samples_loss = []
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user