mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2024-03-22 13:10:55 +08:00
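fix for clipgrad: compute the global gradient norm in float32 and cast it back to each gradient's dtype before clipping (AdaBelief, RMSprop)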
This commit is contained in:
parent
35877dbfd7
commit
2edac3df8c
|
@ -50,11 +50,11 @@ class AdaBelief(nn.OptimizerBase):
|
|||
updates = []
|
||||
|
||||
if self.clipnorm > 0.0:
|
||||
norm = tf.sqrt( sum([tf.reduce_sum(tf.square(g)) for g,v in grads_vars]))
|
||||
norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars]))
|
||||
updates += [ state_ops.assign_add( self.iterations, 1) ]
|
||||
for i, (g,v) in enumerate(grads_vars):
|
||||
if self.clipnorm > 0.0:
|
||||
g = self.tf_clip_norm(g, self.clipnorm, norm)
|
||||
g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) )
|
||||
|
||||
ms = self.ms_dict[ v.name ]
|
||||
vs = self.vs_dict[ v.name ]
|
||||
|
|
|
@ -47,11 +47,11 @@ class RMSprop(nn.OptimizerBase):
|
|||
updates = []
|
||||
|
||||
if self.clipnorm > 0.0:
|
||||
norm = tf.sqrt( sum([tf.reduce_sum(tf.square(g)) for g,v in grads_vars]))
|
||||
norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars]))
|
||||
updates += [ state_ops.assign_add( self.iterations, 1) ]
|
||||
for i, (g,v) in enumerate(grads_vars):
|
||||
if self.clipnorm > 0.0:
|
||||
g = self.tf_clip_norm(g, self.clipnorm, norm)
|
||||
g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) )
|
||||
|
||||
a = self.accumulators_dict[ v.name ]
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user