This commit is contained in:
iperov 2021-10-17 21:37:22 +04:00
parent d8c7cc3d93
commit 14cc9d4e5f

View File

@ -4,16 +4,16 @@ from core.leras import nn
tf = nn.tf
class RMSprop(nn.OptimizerBase):
def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, clipnorm=0.0, name=None, **kwargs):
def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs):
super().__init__(name=name)
if name is None:
raise ValueError('name must be defined.')
self.lr_dropout = lr_dropout
self.lr_cos = lr_cos
self.lr = lr
self.rho = rho
self.clipnorm = clipnorm
with tf.device('/CPU:0') :
@ -58,6 +58,8 @@ class RMSprop(nn.OptimizerBase):
new_a = self.rho * a + (1. - self.rho) * tf.square(g)
lr = tf.constant(self.lr, g.dtype)
if self.lr_cos != 0:
lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0
v_diff = - lr * g / (tf.sqrt(new_a) + np.finfo( g.dtype.as_numpy_dtype ).resolution )
if self.lr_dropout != 1.0: