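# NOTE: the next four lines appear to be hosting-page residue (uploader /
# commit info), not part of the module; kept here as comments.
# akhaliq's picture
# akhaliq HF Staff
# Upload 39 files
# f06aba5 verified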
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
import numpy as np
class LogBuffer:
    """Accumulate named values over time and compute their weighted averages.

    Values are recorded per key with :meth:`update`, each with an associated
    count. :meth:`average` then writes, for every key, the count-weighted
    mean of the recorded values into ``output`` and sets ``ready`` to True.
    """

    def __init__(self) -> None:
        # key -> list of recorded values, in insertion order.
        self.val_history = OrderedDict()
        # key -> list of counts, parallel to ``val_history[key]``.
        self.n_history = OrderedDict()
        # key -> most recently computed average (filled by ``average``).
        self.output = OrderedDict()
        # True once ``average`` has run since the last ``clear_output``.
        self.ready = False

    def clear(self) -> None:
        """Drop all recorded history and any computed output."""
        self.val_history.clear()
        self.n_history.clear()
        self.clear_output()

    def clear_output(self) -> None:
        """Drop the computed averages and mark the buffer as not ready.

        The recorded history is left untouched.
        """
        self.output.clear()
        self.ready = False

    def update(self, vars: dict, count: int = 1) -> None:
        """Record one value per key, all sharing the same ``count``.

        Args:
            vars: Mapping from key to the value to record.
            count: Weight stored alongside every value in ``vars``; it is
                used as the multiplier for that value in :meth:`average`.

        Raises:
            AssertionError: If ``vars`` is not a ``dict``.
        """
        assert isinstance(vars, dict), 'vars must be a dict'
        for key, var in vars.items():
            if key not in self.val_history:
                # First time this key is seen: create both parallel lists.
                self.val_history[key] = []
                self.n_history[key] = []
            self.val_history[key].append(var)
            self.n_history[key].append(count)

    def average(self, n: int = 0) -> None:
        """Average latest n values or all values.

        For every key, computes ``sum(value * count) / sum(count)`` over the
        selected entries, stores it in ``output[key]`` and finally sets
        ``ready`` to True.

        Args:
            n: Number of most recent entries to average. ``0`` selects the
                whole history, because the slice ``[-0:]`` is the full list.

        Raises:
            AssertionError: If ``n`` is negative.
        """
        assert n >= 0, 'n must be non-negative'
        for key in self.val_history:
            values = np.array(self.val_history[key][-n:])
            nums = np.array(self.n_history[key][-n:])
            # Weighted mean: each value is weighted by its recorded count.
            avg = np.sum(values * nums) / np.sum(nums)
            self.output[key] = avg
        self.ready = True