|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GaussianFourierProjection(nn.Module):
    """Gaussian random Fourier features for continuous time values.

    Intended for times ``t`` in ``[0, 1]`` (the range is not enforced).
    Each time is multiplied by a fixed random frequency vector ``W``.
    ``W`` has ``embed_dim // 2`` entries drawn from ``N(0, scale**2)``.
    The output is ``[sin(W t), cos(W t)]`` concatenated along the last
    axis, i.e. ``embed_dim`` features in total.

    Args:
        embed_dim: Total number of output features. Must be a positive
            even integer (half sine, half cosine).
        scale: Standard deviation of the random frequencies. Larger values
            yield higher-frequency features.

    Raises:
        ValueError: If ``embed_dim`` is not a positive even integer.
    """

    def __init__(self, embed_dim, scale):
        super().__init__()
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if embed_dim <= 0 or embed_dim % 2 != 0:
            raise ValueError(
                f"embed_dim must be a positive even integer, got {embed_dim}."
            )
        self.embed_dim = embed_dim
        # W is a buffer: it follows .to()/.cuda() but is not a trainable parameter.
        # It must be persistent. It is sampled randomly, so if it were excluded
        # from state_dict (persistent=False), a reloaded model would draw new
        # frequencies that no longer match the weights trained downstream.
        # Checkpoints saved without "W" can still be loaded with strict=False.
        # In that case W keeps its freshly sampled value (the old behavior).
        self.register_buffer("W", torch.randn(embed_dim // 2) * scale)

    def forward(self, t):
        """Map times to Fourier features.

        Args:
            t: Times as a tensor of any shape, or anything accepted by
                ``torch.as_tensor`` (e.g. a Python float or list). The
                values are placed on ``W``'s device and cast to float32.

        Returns:
            Tensor of shape ``(*t.shape, embed_dim)``. The first
            ``embed_dim // 2`` features are ``sin(W t)``; the rest are
            ``cos(W t)``.
        """
        # No-op for a tensor already on W's device; otherwise convert/move it.
        t = torch.as_tensor(t, device=self.W.device)
        # The trailing singleton axis broadcasts each time against every frequency.
        angles = t.float().unsqueeze(-1) * self.W
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
|
|
|
|
|
|
|
|
class TimeEmbedding(nn.Module):
    """Embed time values into ``hidden_dim``-dimensional vectors.

    Times are first mapped to ``fourier_dim`` features by
    ``GaussianFourierProjection``. The features then pass through a
    two-layer MLP: Linear(fourier_dim -> hidden_dim), SiLU,
    Linear(hidden_dim -> hidden_dim).

    Args:
        hidden_dim: Width of the MLP hidden layer and of the output.
        fourier_dim: Number of Fourier features fed to the MLP. Must be a
            positive even integer (sine/cosine pairs).
        scale: Forwarded to ``GaussianFourierProjection`` as the scale of
            its random frequencies.

    Raises:
        ValueError: If ``fourier_dim`` is not a positive even integer.
    """

    def __init__(self, hidden_dim, fourier_dim, scale):
        super().__init__()
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if fourier_dim <= 0 or fourier_dim % 2 != 0:
            raise ValueError(
                "fourier_dim must be a positive even integer for sine/cosine "
                f"pairs, got {fourier_dim}."
            )
        self.fourier = GaussianFourierProjection(fourier_dim, scale)
        self.mlp = nn.Sequential(
            nn.Linear(fourier_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, t):
        """Embed times ``t``.

        Args:
            t: Tensor of time values. Any leading shape is kept.

        Returns:
            Tensor of shape ``(*t.shape, hidden_dim)``.
        """
        return self.mlp(self.fourier(t))