| import torch | |
| import numpy as np | |
| from scipy.linalg import sqrtm | |
# Lookup table from nucleotide letter to integer class index. Lower-case
# letters map to the same index as their upper-case counterpart so that
# mixed-case sequences encode identically.
_BASE_TO_INDEX = {
    'A': 0, 'C': 1, 'G': 2, 'T': 3,
    'a': 0, 'c': 1, 'g': 2, 't': 3,
}


def dna_to_tensor(seq):
    """Encode a DNA sequence as a 1-D tensor of integer base indices.

    Each base is mapped as A -> 0, C -> 1, G -> 2, T -> 3 (case-insensitive).

    Args:
        seq: Iterable of single-character bases, typically a ``str``.

    Returns:
        A ``torch.long`` tensor with one index per base of ``seq``; an empty
        sequence yields an empty tensor.

    Raises:
        KeyError: If ``seq`` contains a symbol other than A, C, G or T
            (for example the ambiguity code ``N``).
    """
    indices = [_BASE_TO_INDEX[base] for base in seq]
    return torch.tensor(indices, dtype=torch.long)
def _as_embedding_array(embeds):
    """Return ``embeds`` as a float64 NumPy array, detaching torch tensors."""
    if isinstance(embeds, torch.Tensor):
        # NumPy cannot read tensors that live on an accelerator or that
        # are still attached to the autograd graph.
        embeds = embeds.detach().cpu().numpy()
    return np.asarray(embeds, dtype=np.float64)


def compute_fbd(true_seqs, gen_seqs, score_model):
    """Compute the Frechet Biological Distance (FBD) between two sequence sets.

    A Gaussian is fitted to the embeddings of each set (mean ``mu`` and
    covariance ``sigma``) and the Frechet distance between the two Gaussians
    is returned::

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))

    Args:
        true_seqs: Reference sequences, passed unchanged to ``score_model``.
        gen_seqs: Generated sequences, passed unchanged to ``score_model``.
        score_model: Callable returning one embedding row per sequence as a
            NumPy array or torch tensor.
            NOTE(review): the previous code invoked ``score_model()`` with no
            arguments, so neither sequence set was ever used; this assumes the
            model accepts the sequence batch directly -- confirm against the
            caller.

    Returns:
        The distance as a ``float``. ``nan`` is returned when either
        embedding set contains NaN or has fewer than two rows (a covariance
        cannot be estimated from a single sample).
    """
    embeds1 = _as_embedding_array(score_model(true_seqs))
    embeds2 = _as_embedding_array(score_model(gen_seqs))

    for embeds in (embeds1, embeds2):
        if embeds.ndim == 0 or embeds.shape[0] < 2 or np.isnan(embeds).any():
            return float('nan')

    # np.cov collapses to a 0-d array for one-dimensional embeddings;
    # promote to 2-D so the matrix product and sqrtm below stay valid.
    mu1 = embeds1.mean(axis=0)
    sigma1 = np.atleast_2d(np.cov(embeds1, rowvar=False))
    mu2 = embeds2.mean(axis=0)
    sigma2 = np.atleast_2d(np.cov(embeds2, rowvar=False))

    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    # Round-off can leave a tiny imaginary component on the matrix square
    # root; only the real part is meaningful here.
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    dist = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return float(dist)