mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
74a3fc97d0
Need readme
46 lines
1.1 KiB
Python
46 lines
1.1 KiB
Python
from collections import deque

import torch
|
|
|
|
|
|
# Module-level slots that data_parallel_workaround() overwrites on each
# multi-GPU call with the latest per-device outputs and model replicas,
# so those objects stay referenced after the function returns.
_output_ref = _replicas_ref = None
|
|
|
|
def data_parallel_workaround(model, *args):
    """Run ``model`` on ``args`` split across every visible CUDA device.

    Replicates ``model`` onto each visible GPU, scatters the positional
    arguments along their first dimension, applies each replica to its
    chunk in parallel, and gathers the results onto the first device.

    When no CUDA device is visible, ``model(*args)`` is called directly
    instead of failing with an ``IndexError`` on ``device_ids[0]``.

    Args:
        model: the ``torch.nn.Module`` to run.
        *args: positional inputs for ``model``; tensors are split along
            dim 0 by ``torch.nn.parallel.scatter``.

    Returns:
        The gathered output on ``device_ids[0]``, or the plain
        ``model(*args)`` result when no GPU is available.
    """
    global _output_ref
    global _replicas_ref

    device_ids = list(range(torch.cuda.device_count()))
    if not device_ids:
        # Nothing to parallelise over: run the model where it already lives.
        return model(*args)

    output_device = device_ids[0]
    replicas = torch.nn.parallel.replicate(model, device_ids)
    # args.shape = (num_args, batch, ...)
    inputs = torch.nn.parallel.scatter(args, device_ids)
    # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
    # scatter can return fewer chunks than devices (e.g. batch smaller than
    # the GPU count), so drop the replicas that would receive no input.
    replicas = replicas[:len(inputs)]
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
    y_hat = torch.nn.parallel.gather(outputs, output_device)
    # Keep module-level references to the per-device outputs and replicas
    # until the next call. NOTE(review): the reason for this "workaround"
    # is not shown in this file -- presumably to keep these objects alive
    # after returning; confirm before removing.
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat
|
|
|
|
|
|
class ValueWindow:
    """Fixed-size sliding window over the most recently appended values.

    Only the last ``window_size`` values are retained; the oldest value is
    discarded each time a new one arrives in a full window.

    Args:
        window_size: maximum number of values kept (must be >= 1).

    Raises:
        ValueError: if ``window_size`` is less than 1.
    """

    def __init__(self, window_size=100):
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self._window_size = window_size
        # deque(maxlen=...) evicts the oldest item in O(1) per append. A
        # slice-based trim like values[-(n - 1):] breaks at n == 1, because
        # values[-0:] is the entire list and nothing is ever dropped.
        self._values = deque(maxlen=window_size)

    def append(self, x):
        """Add ``x`` to the window, dropping the oldest value if it is full."""
        self._values.append(x)

    @property
    def sum(self):
        """Sum of the values currently in the window (0 when empty)."""
        return sum(self._values)

    @property
    def count(self):
        """Number of values currently in the window."""
        return len(self._values)

    @property
    def average(self):
        """Mean of the window's values, or 0.0 when the window is empty."""
        return self.sum / max(1, self.count)

    def reset(self):
        """Remove every value from the window."""
        self._values.clear()
|