File size: 965 Bytes
94c2704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import numpy as np
from scipy.linalg import sqrtm

# Nucleotide -> integer index. Lowercase keys (e.g. soft-masked regions of a
# sequence) encode exactly like their uppercase counterparts.
_NUCLEOTIDE_INDEX = {
    'A': 0, 'C': 1, 'G': 2, 'T': 3,
    'a': 0, 'c': 1, 'g': 2, 't': 3,
}


def dna_to_tensor(seq):
    """Encode a DNA sequence as a 1-D tensor of integer nucleotide indices.

    Bases map A->0, C->1, G->2, T->3 (case-insensitive).

    Args:
        seq: Iterable of single-character nucleotide strings, typically a str.

    Returns:
        A 1-D ``torch.long`` tensor with one index per base (empty for an
        empty sequence).

    Raises:
        KeyError: If ``seq`` contains anything other than A/C/G/T in either
            case, e.g. the ambiguity code ``N``.
    """
    try:
        indices = [_NUCLEOTIDE_INDEX[base] for base in seq]
    except KeyError as err:
        # Keep the KeyError type callers may rely on, but name the bad base.
        raise KeyError(f"invalid nucleotide {err.args[0]!r} in DNA sequence") from None
    return torch.tensor(indices, dtype=torch.long)


def _embeddings_to_numpy(embeds):
    """Convert model embeddings (torch.Tensor or array-like) to a float64 ndarray."""
    if isinstance(embeds, torch.Tensor):
        # Detach from autograd and move off any accelerator before NumPy sees it.
        embeds = embeds.detach().cpu().numpy()
    return np.asarray(embeds, dtype=np.float64)


def _gaussian_frechet_distance(embeds1, embeds2, eps):
    """Return the Frechet distance between Gaussians fit to two sample arrays."""
    mu1, mu2 = embeds1.mean(axis=0), embeds2.mean(axis=0)
    # np.cov squeezes a single-feature covariance down to a 0-d array, but
    # sqrtm needs a 2-D matrix.
    sigma1 = np.atleast_2d(np.cov(embeds1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(embeds2, rowvar=False))
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Near-singular covariances (few samples / many features) can make the
        # product's square root blow up; regularise the diagonal and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        # sqrtm may return negligible imaginary parts from numerical error;
        # keep only the real part.
        covmean = covmean.real

    return float(ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean))


def compute_fbd(true_seqs, gen_seqs, score_model, eps=1e-6):
    """Compute the Frechet Biological Distance (FBD) between two sequence sets.

    Both sets are embedded with ``score_model``. A multivariate Gaussian is then
    fit to each embedding cloud (sample mean and covariance). The result is the
    Frechet (2-Wasserstein) distance between the two Gaussians::

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Args:
        true_seqs: Reference (real) sequences, passed unchanged to ``score_model``.
        gen_seqs: Generated sequences, passed unchanged to ``score_model``.
        score_model: Callable mapping a collection of sequences to their
            embeddings. It is expected to return an array-like or
            ``torch.Tensor`` of shape ``(n_samples, n_features)``. A 1-D
            result is treated as a single feature.
        eps: Diagonal offset added to both covariances, used only if the
            matrix square root of their product is not finite.

    Returns:
        The distance as a float. Returns ``nan`` if either embedding set has
        fewer than two samples (covariance undefined) or contains non-finite
        values.
    """
    embeds1 = _embeddings_to_numpy(score_model(true_seqs))
    embeds2 = _embeddings_to_numpy(score_model(gen_seqs))

    if len(embeds1) < 2 or len(embeds2) < 2:
        return float('nan')
    if not (np.isfinite(embeds1).all() and np.isfinite(embeds2).all()):
        return float('nan')

    return _gaussian_frechet_distance(embeds1, embeds2, eps)