mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2024-03-22 13:10:55 +08:00
_
This commit is contained in:
parent
d8c7cc3d93
commit
14cc9d4e5f
|
@ -4,16 +4,16 @@ from core.leras import nn
|
|||
tf = nn.tf
|
||||
|
||||
class RMSprop(nn.OptimizerBase):
|
||||
def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, clipnorm=0.0, name=None, **kwargs):
|
||||
def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs):
|
||||
super().__init__(name=name)
|
||||
|
||||
if name is None:
|
||||
raise ValueError('name must be defined.')
|
||||
|
||||
self.lr_dropout = lr_dropout
|
||||
self.lr_cos = lr_cos
|
||||
self.lr = lr
|
||||
self.rho = rho
|
||||
|
||||
self.clipnorm = clipnorm
|
||||
|
||||
with tf.device('/CPU:0') :
|
||||
|
@ -58,6 +58,8 @@ class RMSprop(nn.OptimizerBase):
|
|||
new_a = self.rho * a + (1. - self.rho) * tf.square(g)
|
||||
|
||||
lr = tf.constant(self.lr, g.dtype)
|
||||
if self.lr_cos != 0:
|
||||
lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0
|
||||
|
||||
v_diff = - lr * g / (tf.sqrt(new_a) + np.finfo( g.dtype.as_numpy_dtype ).resolution )
|
||||
if self.lr_dropout != 1.0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user