# MadSBM/src/utils/model_utils.py
import sys
import math

import torch


def _print(s):
    """Print and flush stdout immediately so logs appear in order during training."""
    print(s, flush=True)


def compute_grad_norms(params):
    """Compute the global L2 norm over the gradients of an iterable of parameters."""
    sqrd_sum = 0.0
    for p in params:
        if p.grad is not None:
            sqrd_sum += p.grad.norm(2).item() ** 2
    return sqrd_sum ** 0.5
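

# Example usage (a minimal sketch; the tiny linear model below is an
# illustrative assumption, not part of this module):
#
#     model = torch.nn.Linear(4, 2)
#     model(torch.randn(8, 4)).sum().backward()
#     grad_norm = compute_grad_norms(model.parameters())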


class CosineWarmup(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup to the base learning rate, then cosine decay down to
    eta_ratio * base_lr over the remaining steps."""

    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_ratio = eta_ratio  # ratio of minimum to maximum learning rate
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        # Linear warmup: ramp each base learning rate from 0 up to base_lr.
        if step < self.warmup_steps:
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]
        # Cosine decay from base_lr down to eta_ratio * base_lr.
        progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        lr_mult = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
        return [base_lr * lr_mult for base_lr in self.base_lrs]
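

if __name__ == "__main__":
    # Minimal usage sketch / smoke test. The tiny model, optimizer choice,
    # and step counts below are illustrative assumptions, not values taken
    # from the MadSBM training setup.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = CosineWarmup(optimizer, warmup_steps=10, total_steps=100)
    for step in range(100):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the warmup/cosine schedule once per batch
        if step % 25 == 0:
            _print(f"step {step}: lr={scheduler.get_last_lr()[0]:.2e}, "
                   f"grad_norm={compute_grad_norms(model.parameters()):.3f}")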