MadSBM / src /utils /time_utils.py
Shrey Goel
initial commit
94c2704
import torch
import torch.nn as nn
# -------------------------
# Timestep embeddings
# -------------------------
class GaussianFourierProjection(nn.Module):
    """
    Gaussian Fourier features for a continuous time value t.

    Intended for t in [0, 1] (not enforced). At construction, ``embed_dim // 2``
    random frequencies W = randn(embed_dim // 2) * scale are drawn once. The
    forward pass then maps t to ``embed_dim`` features laid out as
    [sin(W t), cos(W t)]: the sine half first, then the cosine half.

    Args:
        embed_dim: Number of output features. Must be even (half sine, half cosine).
        scale: Standard-deviation multiplier for the random frequencies.
        persistent: If True, W is stored in ``state_dict`` and survives
            save/load. Defaults to False, which preserves the original behaviour.
            In that case W is NOT checkpointed and is re-drawn from the global
            RNG every time the module is constructed. A reloaded model then
            only reproduces its trained time embedding if the RNG state at
            construction matches.
    """

    def __init__(self, embed_dim, scale, *, persistent=False):
        super().__init__()
        assert embed_dim % 2 == 0, "embed_dim must be even."
        self.embed_dim = embed_dim
        # Non-trainable random frequencies. A buffer (not a Parameter) so it
        # moves with module.to(...) but is never updated by the optimizer.
        self.register_buffer(
            "W", torch.randn(embed_dim // 2) * scale, persistent=persistent
        )

    def forward(self, t):
        """
        Args:
            t: Tensor of time values of any shape S, e.g. (B,) or a 0-dim scalar.
                Integer inputs are accepted and cast to float.

        Returns:
            Tensor of shape S + (embed_dim,).
        """
        # Add a trailing axis so t broadcasts against W: (..., 1) * (embed_dim // 2,).
        t = t.float().unsqueeze(-1)
        angles = t * self.W  # (..., embed_dim // 2)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
class TimeEmbedding(nn.Module):
    """
    Learned embedding of a continuous time value.

    Pipeline:
        GaussianFourierProjection(fourier_dim, scale)
        -> Linear(fourier_dim, hidden_dim) -> SiLU -> Linear(hidden_dim, hidden_dim)

    Args:
        hidden_dim: Width of the MLP and of the returned embedding.
        fourier_dim: Number of Fourier features fed into the MLP; must be even.
        scale: Frequency scale forwarded to GaussianFourierProjection.
    """

    def __init__(self, hidden_dim, fourier_dim, scale):
        super().__init__()
        assert not fourier_dim % 2, "fourier_dim must be even for sine/cosine pairs."
        # Build the projection first, then the Linears in their original order,
        # so that seeded initialisation draws random numbers in the same sequence.
        self.fourier = GaussianFourierProjection(fourier_dim, scale)
        in_proj = nn.Linear(fourier_dim, hidden_dim)
        out_proj = nn.Linear(hidden_dim, hidden_dim)
        # Sequential positions 0/1/2 keep the state_dict keys (mlp.0.*, mlp.2.*) stable.
        self.mlp = nn.Sequential(in_proj, nn.SiLU(), out_proj)

    def forward(self, t):
        """Embed time values t: shape S -> S + (fourier_dim,) -> S + (hidden_dim,)."""
        return self.mlp(self.fourier(t))