MadSBM / src /utils /eval_utils.py
Shrey Goel
initial commit
94c2704
import torch
import numpy as np
from scipy.linalg import sqrtm
def dna_to_tensor(seq):
    """Encode a DNA sequence as a 1-D ``torch.long`` tensor of base indices.

    Each symbol of ``seq`` is looked up in the fixed alphabet
    A -> 0, C -> 1, G -> 2, T -> 3. A symbol outside that alphabet
    (lowercase letters included) raises ``KeyError``.
    """
    base_to_index = dict(zip("ACGT", range(4)))
    encoded = list(map(base_to_index.__getitem__, seq))
    return torch.tensor(encoded, dtype=torch.long)
def _as_embedding_matrix(embeds):
    """Coerce a score-model output to a 2-D float64 array (samples x features)."""
    if isinstance(embeds, torch.Tensor):
        # Detach from autograd and move to host memory before handing to NumPy.
        embeds = embeds.detach().to(device="cpu", dtype=torch.float64).numpy()
    arr = np.atleast_1d(np.asarray(embeds, dtype=np.float64))
    if arr.size == 0:
        return np.empty((0, 0), dtype=np.float64)
    if arr.ndim == 1:
        # One scalar per sample: treat as a single feature column.
        return arr[:, None]
    # Flatten any trailing dimensions so each row is one sample.
    return arr.reshape(arr.shape[0], -1)


def compute_fbd(true_seqs, gen_seqs, score_model):
    """
    Compute the Frechet Biological Distance (FBD) between two sequence sets.

    Each set is embedded with ``score_model``, a Gaussian is fitted to each set
    of embeddings (mean and covariance), and the Frechet distance between the
    two Gaussians is returned:

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))

    Args:
        true_seqs: Reference sequences, passed unchanged to ``score_model``.
        gen_seqs: Generated sequences, passed unchanged to ``score_model``.
        score_model: Callable taking one argument and returning one embedding
            per sequence as an array-like or ``torch.Tensor`` whose first axis
            indexes samples.
            NOTE(review): the input encoding expected by the model is not
            visible here -- confirm callers pass sequences in the right form.

    Returns:
        The distance as a NumPy float, or ``float('nan')`` when either
        embedding set has fewer than two samples (covariance is undefined)
        or contains non-finite values.

    Raises:
        ValueError: If the two embedding sets have different feature sizes.
    """
    # Embed each set from its own input; the sets must not share one call.
    embeds1 = _as_embedding_matrix(score_model(true_seqs))
    embeds2 = _as_embedding_matrix(score_model(gen_seqs))

    # The unbiased covariance needs at least two observations per set.
    if embeds1.shape[0] < 2 or embeds2.shape[0] < 2:
        return float('nan')
    if not (np.isfinite(embeds1).all() and np.isfinite(embeds2).all()):
        return float('nan')
    if embeds1.shape[1] != embeds2.shape[1]:
        raise ValueError(
            "embedding feature sizes differ: "
            f"{embeds1.shape[1]} vs {embeds2.shape[1]}"
        )

    # np.cov returns a 0-d array for a single feature; keep matrices 2-D so
    # sqrtm and trace work uniformly.
    mu1 = embeds1.mean(axis=0)
    sigma1 = np.atleast_2d(np.cov(embeds1, rowvar=False))
    mu2 = embeds2.mean(axis=0)
    sigma2 = np.atleast_2d(np.cov(embeds2, rowvar=False))

    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    # Numerical error can leave a tiny imaginary component; discard it.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)