File size: 4,271 Bytes
3527383
 
6409d51
 
3527383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16339c9
3527383
 
a33cd10
3527383
a620d8f
 
 
 
 
 
 
 
16339c9
a620d8f
 
 
 
 
2237f88
16339c9
a620d8f
 
a33cd10
9fdebfd
 
 
 
16339c9
a620d8f
3527383
9fdebfd
6409d51
 
 
 
 
9fdebfd
6409d51
 
 
 
 
 
3527383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fdebfd
3527383
 
 
 
 
 
 
 
 
 
 
 
 
 
6409d51
 
 
3527383
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
from transformers import AutoTokenizer
from pathlib import Path
import inspect

from models.peptide_classifiers import *
from utils.parsing import parse_guidance_args
args = parse_guidance_args()


# MOO hyper-parameters
step_size = 1 / 100  # forwarded to solver.multi_guidance_sample as step_size
n_samples = 1  # sequences per batch; the scoring loop calls .item(), which needs a single sample
vocab_size = 24
source_distribution = "uniform"  # the sampling loop accepts "uniform" or "mask"
# Prefer the first GPU, but fall back to CPU so the script does not crash on
# the first .to(device) when CUDA is unavailable.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

length = args.length
target = args.target_protein
# NOTE: `motifs` is only defined when args.motifs is provided.
if args.motifs:
    motifs = parse_motifs(args.motifs).to(device)
    print(motifs)

# Tokenize the target protein once; the tokenizer is reused later to decode samples.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
target_sequence = tokenizer(target, return_tensors='pt').to(device)

# Load Models
solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)

# One score model per requested objective. Models are appended in the fixed
# order below, which is independent of the order of args.objectives.
score_models = []
# Default value so that later reads of `motif_penalty` do not raise NameError
# when neither 'Motif' nor 'Specificity' is requested.
motif_penalty = False
if 'Hemolysis' in args.objectives:
    hemolysis_model = HemolysisModel(device=device)
    score_models.append(hemolysis_model)
if 'Non-Fouling' in args.objectives:
    nonfouling_model = NonfoulingModel(device=device)
    score_models.append(nonfouling_model)
if 'Solubility' in args.objectives:
    solubility_model = SolubilityModel(device=device)
    score_models.append(solubility_model)
if 'Half-Life' in args.objectives:
    halflife_model = HalfLifeModel(device=device)
    score_models.append(halflife_model)
if 'Affinity' in args.objectives:
    affinity_predictor = load_affinity_predictor(device)
    affinity_model = AffinityModel(affinity_predictor, target_sequence, device)
    score_models.append(affinity_model)
if 'Motif' in args.objectives or 'Specificity' in args.objectives:
    # `motifs` only exists when args.motifs was provided; fail early with a
    # clear message instead of an opaque NameError after loading the checkpoint.
    if not args.motifs:
        raise ValueError(
            "The 'Motif' and 'Specificity' objectives require motifs "
            "(args.motifs), but none were provided."
        )
    bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
    # 'Specificity' switches on the penalty option of the motif model.
    motif_penalty = 'Specificity' in args.objectives
    motif_model = MotifModel(bindevaluator, target_sequence['input_ids'], motifs, penalty=motif_penalty)
    score_models.append(motif_model)

# Header row for the output file: "Binder" followed by the requested objectives,
# built from the list's repr with brackets, spaces and quotes stripped.
objective_line = "Binder," + str(args.objectives)[1:-1].replace(' ', '').replace("'", "") + '\n'

# Reuse the existing output file only when its first line already is this
# header; otherwise (missing file, empty file, different header) start the
# file over with just the header.
output_path = Path(args.output_file)
header_matches = False
if output_path.exists():
    with open(args.output_file, 'r') as existing:
        # readline() returns '' for an empty file, which never equals the header.
        header_matches = existing.readline() == objective_line

if not header_matches:
    with open(args.output_file, 'w') as out:
        out.write(objective_line)

# Main loop: per batch, draw an initial sequence, run guided sampling, score
# the result with every loaded score model and append one row to the output file.
for _ in range(args.n_batches):
    if source_distribution == "uniform":
        # Random token ids in [4, vocab_size). The original author flagged this line with "CHANGE!".
        x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device)
    elif source_distribution == "mask":
        # Every position starts at token id 3 (presumably the mask/placeholder id - TODO confirm).
        x_init = torch.full((n_samples, length), 3, dtype=torch.long, device=device)
    else:
        raise NotImplementedError

    # Wrap each sequence with token id 0 at the start and token id 2 at the end
    # (presumably the tokenizer's <cls>/<eos> ids - TODO confirm).
    zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
    twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
    x_init = torch.cat([zeros, x_init, twos], dim=1)

    x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
                                      step_size=step_size,
                                      verbose=True,
                                      time_grid=torch.tensor([0.0, 1.0-1e-3]),
                                      score_models=score_models,
                                      num_objectives=len(score_models) + int(motif_penalty),
                                      weights=args.weights)

    # Decode to strings, drop spaces, and cut 5 characters from each end
    # (presumably the "<cls>" / "<eos>" markers - TODO confirm).
    samples = x_1.tolist()
    samples = [tokenizer.decode(seq).replace(' ', '')[5:-5] for seq in samples]
    print(samples)

    scores = []
    for model in score_models:
        # Some score models take a time argument `t`; pass 1 when they do.
        scorer = model.forward if hasattr(model, 'forward') else model
        if 't' in inspect.signature(scorer).parameters:
            candidate_scores = model(x_1, 1)
        else:
            candidate_scores = model(x_1)

        # Scale the affinity score by 10. The model type is checked rather than
        # args.objectives[i]: score_models is built in a fixed order that need
        # not match the order of args.objectives, so indexing the objectives by
        # model position could scale the wrong score.
        if isinstance(model, AffinityModel):
            candidate_scores = 10 * candidate_scores

        # A model may return several scores as a tuple; flatten them into the row.
        if isinstance(candidate_scores, tuple):
            scores.extend(score.item() for score in candidate_scores)
        else:
            scores.append(candidate_scores.item())
    print(scores)

    # One CSV row per batch: first sampled sequence followed by its scores.
    with open(args.output_file, 'a') as f:
        f.write(",".join([samples[0], *(str(score) for score in scores)]) + '\n')