import torch
import numpy as np
from scipy.linalg import sqrtm

# Nucleotide -> integer class index used by dna_to_tensor.
_BASE_TO_INDEX = {'A': 0, 'C': 1, 'G': 2, 'T': 3}


def dna_to_tensor(seq):
    """Encode a DNA sequence as a 1-D tensor of integer class indices.

    Bases map A->0, C->1, G->2, T->3; lowercase bases are accepted too.

    Args:
        seq: A string (or any iterable of single-character strings) of bases.

    Returns:
        A ``torch.long`` tensor of shape ``(len(seq),)``.

    Raises:
        KeyError: If ``seq`` contains a character other than A/C/G/T
            (in either case), e.g. an ambiguity code such as ``N``.
    """
    try:
        indices = [_BASE_TO_INDEX[base.upper()] for base in seq]
    except KeyError as err:
        raise KeyError(
            f"unsupported nucleotide {err.args[0]!r}; expected A, C, G or T"
        ) from None
    return torch.tensor(indices, dtype=torch.long)


def _to_embedding_matrix(embeds):
    """Return model output as a float64 ``(n_samples, n_features)`` NumPy array."""
    if isinstance(embeds, torch.Tensor):
        # Detach / move to CPU / upcast so tensors that require grad, live on
        # an accelerator, or use bf16 convert to NumPy without error.
        embeds = embeds.detach().cpu().double().numpy()
    arr = np.asarray(embeds, dtype=np.float64)
    if arr.ndim <= 1:
        # A flat vector is treated as n samples of a single feature.
        return arr.reshape(-1, 1)
    # Collapse any trailing dimensions into a single feature axis.
    return arr.reshape(arr.shape[0], int(np.prod(arr.shape[1:])))


def _frechet_distance(mu1, sigma1, mu2, sigma2, eps):
    """Return the Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # The covariance product is (near-)singular: retry with a small
        # diagonal offset on both covariances.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        # sqrtm can return a complex matrix due to numerical error; keep only
        # the real part.
        covmean = covmean.real
    return float(ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean))


def compute_fbd(true_seqs, gen_seqs, score_model, *, eps=1e-6):
    """Compute the Frechet Biological Distance (FBD) between two sequence sets.

    FBD is the Frechet (2-Wasserstein) distance between Gaussians fitted to
    the embeddings of the true and the generated sequences:

        ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * (S1 @ S2)^(1/2))

    Args:
        true_seqs: Reference sequences, passed unchanged to ``score_model``.
        gen_seqs: Generated sequences, passed unchanged to ``score_model``.
        score_model: Callable mapping a collection of sequences to embeddings
            (NumPy array or torch tensor) with one row per sequence.
            NOTE(review): the input format it expects (raw strings vs.
            encoded tensors, e.g. via ``dna_to_tensor``) is not visible
            here -- confirm against callers.
        eps: Diagonal offset added to both covariances when the matrix square
            root of their product is not finite.

    Returns:
        The distance as a float, or ``nan`` if either embedding set contains
        non-finite values or has fewer than two samples (the sample
        covariance is undefined).

    Raises:
        ValueError: If the two embedding sets have different feature sizes.
    """
    embeds1 = _to_embedding_matrix(score_model(true_seqs))
    embeds2 = _to_embedding_matrix(score_model(gen_seqs))

    for embeds in (embeds1, embeds2):
        if embeds.shape[0] < 2 or not np.isfinite(embeds).all():
            return float('nan')
    if embeds1.shape[1] != embeds2.shape[1]:
        raise ValueError(
            f"embedding size mismatch: {embeds1.shape[1]} vs {embeds2.shape[1]}"
        )

    # atleast_2d: np.cov returns a 0-d array for single-feature input, which
    # sqrtm cannot handle.
    mu1 = embeds1.mean(axis=0)
    sigma1 = np.atleast_2d(np.cov(embeds1, rowvar=False))
    mu2 = embeds2.mean(axis=0)
    sigma2 = np.atleast_2d(np.cov(embeds2, rowvar=False))
    return _frechet_distance(mu1, sigma1, mu2, sigma2, eps)