import torch
import torch.nn as nn
import numpy as np


class Base(nn.Module):
    def __init__(self, stop_threshold):
        super().__init__()

        self.init_model()
        self.num_params()

        self.register_buffer("step", torch.zeros(1, dtype=torch.long))
        self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))

|
    @property
    def r(self):
        return self.decoder.r.item()

    @r.setter
    def r(self, value):
        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)

    def init_model(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def finetune_partial(self, whitelist_layers):
        self.zero_grad()
        for name, child in self.named_children():
            if name in whitelist_layers:
                print("Trainable Layer: %s" % name)
                print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
            else:
                # Freeze every layer that is not whitelisted
                for param in child.parameters():
                    param.requires_grad = False

    def get_step(self):
        return self.step.data.item()

    def reset_step(self):
        # assignment to parameters or buffers is overloaded, updates internal dict entry
        self.step = self.step.data.new_tensor(1)

    def log(self, path, msg):
        with open(path, "a") as f:
            print(msg, file=f)

    def load(self, path, device, optimizer=None):
        # Map loaded tensors onto the given device rather than the one they were saved from
        checkpoint = torch.load(str(path), map_location=device)
        self.load_state_dict(checkpoint["model_state"], strict=False)

        if "optimizer_state" in checkpoint and optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state"])

    def save(self, path, optimizer=None):
        # Always save the model state; include the optimizer state when one is given
        checkpoint = {"model_state": self.state_dict()}
        if optimizer is not None:
            checkpoint["optimizer_state"] = optimizer.state_dict()
        torch.save(checkpoint, str(path))

    def num_params(self, print_out=True):
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        if print_out:
            print("Trainable Parameters: %.3fM" % parameters)
        return parameters
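

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the training pipeline).
# `_ToyDecoder` and `ToyModel` are hypothetical stand-ins: Base assumes its
# subclass owns a `decoder` submodule exposing an `r` buffer, which the real
# synthesizer model provides. The stop_threshold value here is arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _ToyDecoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 8)
            # Base.r reads and rewrites this buffer
            self.register_buffer("r", torch.tensor(2, dtype=torch.long))

    class ToyModel(Base):
        def __init__(self, stop_threshold=-3.4):
            super().__init__(stop_threshold)
            self.decoder = _ToyDecoder()
            self.init_model()  # re-run weight init now that submodules exist

    model = ToyModel()
    print("r =", model.r)            # reads decoder.r via the property
    model.r = 5                      # setter swaps the decoder buffer in place
    print("step =", model.get_step())
    model.save("toy_checkpoint.pt")  # model-only checkpoint
    model.load("toy_checkpoint.pt", device="cpu")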