File size: 4,090 Bytes
4ed0ba3
 
 
 
 
 
 
 
c6b3e88
 
4ed0ba3
 
 
 
 
1250540
4ed0ba3
 
 
 
 
 
1250540
4ed0ba3
 
 
1250540
4ed0ba3
 
 
 
1250540
4ed0ba3
 
1250540
4ed0ba3
 
1250540
4ed0ba3
 
 
 
 
 
1250540
4ed0ba3
1250540
 
4ed0ba3
 
 
1250540
4ed0ba3
1250540
4ed0ba3
 
1250540
4ed0ba3
 
 
 
 
 
 
 
1250540
4ed0ba3
1250540
4ed0ba3
 
1250540
4ed0ba3
 
 
1250540
 
 
 
 
 
 
 
4ed0ba3
 
 
 
 
 
 
 
 
 
 
1250540
 
 
 
 
 
4ed0ba3
 
 
 
302efca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from torch.distributions.categorical import Categorical
import numpy as np
import pandas as pd

# Hugging Face Hub checkpoint used for both the tokenizer and the masked-LM weights
MODEL_CHECKPOINT = "ChatterjeeLab/PepMLM-650M"

# Instantiated once at import time; the functions below read these module globals
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(MODEL_CHECKPOINT)

def compute_pseudo_perplexity(model, tokenizer, protein_seq: str, binder_seq: str) -> float:
    """Return the pseudo-perplexity of ``binder_seq`` conditioned on ``protein_seq``.

    The two sequences are concatenated and encoded together. Each binder
    position is then masked in turn, the model's loss for recovering the
    original token at that position is recorded, and the exponential of the
    mean loss over all binder positions is returned.

    Args:
        model: Masked language model. It is called as
            ``model(input_ids, labels=...)`` and must expose ``.device`` and
            return an object with a scalar ``.loss``.
        tokenizer: Tokenizer providing ``encode(..., return_tensors='pt')``
            and ``mask_token_id``.
        protein_seq: Target protein sequence placed before the binder.
        binder_seq: Binder sequence to score; must be non-empty.

    Returns:
        ``exp`` of the average single-position masked loss.

    Raises:
        ValueError: If ``binder_seq`` is empty (the average would be a
            division by zero).
    """
    binder_len = len(binder_seq)
    if binder_len == 0:
        raise ValueError("binder_seq must be a non-empty sequence")

    tensor_input = tokenizer.encode(
        protein_seq + binder_seq, return_tensors='pt'
    ).to(model.device)

    total_loss = 0.0
    # Binder positions are addressed from the end of the encoded sequence,
    # skipping the final token.
    # NOTE(review): this assumes one token per binder character followed by a
    # single trailing special token -- confirm against the tokenizer in use.
    for offset in range(binder_len, 0, -1):
        position = -offset - 1

        # Hide exactly one binder token per forward pass
        masked_input = tensor_input.clone()
        masked_input[0, position] = tokenizer.mask_token_id

        # -100 marks positions the loss ignores; only the masked slot is scored
        labels = torch.full_like(tensor_input, -100)
        labels[0, position] = tensor_input[0, position]

        with torch.no_grad():
            total_loss += model(masked_input, labels=labels).loss.item()

    return np.exp(total_loss / binder_len)
    
def generate_peptide(protein_seq, peptide_length, top_k, num_binders):
    """Sample peptide binders for a target protein and score each of them.

    The target sequence is followed by ``peptide_length`` mask tokens, the
    module-level ``model`` predicts logits for the masked positions, and each
    binder is drawn by top-k sampling at every position. Every binder is then
    scored with ``compute_pseudo_perplexity``.

    Args:
        protein_seq: Target protein sequence.
        peptide_length: Number of mask tokens appended to the target
            (converted with ``int``).
        top_k: Number of highest-scoring tokens to sample from at each masked
            position (converted with ``int``).
        num_binders: Number of binders to sample (converted with ``int``).

    Returns:
        A tuple ``(df, csv_path)``: a DataFrame with columns ``"Binder"`` and
        ``"Perplexity"``, and the path of a CSV file holding the same table.
    """
    import os
    import tempfile

    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    # The prompt is the same for every binder, so the expensive forward pass
    # is done once and all samples are drawn from the same logits.
    # NOTE(review): this relies on the model being in inference (eval) mode so
    # that repeated forward passes would be identical -- confirm if the model
    # setup ever changes.
    input_sequence = protein_seq + '<mask>' * peptide_length
    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    # Restrict each masked position to its top-k candidates and renormalise
    top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
    sampler = Categorical(torch.nn.functional.softmax(top_k_logits, dim=-1))

    binders_with_ppl = []
    for _ in range(num_binders):
        # Draw one candidate per masked position, then map the sampled rank
        # back to the corresponding vocabulary id
        predicted_indices = sampler.sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)

        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
        binders_with_ppl.append([generated_binder, ppl_value])

    df = pd.DataFrame(binders_with_ppl, columns=["Binder", "Perplexity"])

    # Write into a fresh directory per call: a fixed path in the working
    # directory would let overlapping requests overwrite each other's file.
    output_filename = os.path.join(tempfile.mkdtemp(prefix="pepmlm_"), "output.csv")
    df.to_csv(output_filename, index=False)

    return df, output_filename

# Input widgets, listed in the positional order of generate_peptide's parameters
input_components = [
    gr.Textbox(label="Protein Sequence", info="Enter protein sequence here", type="text"),
    gr.Slider(3, 50, value=15, label="Peptide Length", step=1, info='Default value is 15'),
    gr.Slider(1, 10, value=3, label="Top K Value", step=1, info='Default value is 3'),
    gr.Dropdown(choices=[1, 2, 4, 8, 16, 32], label="Number of Binders", value=1),
]

# Output widgets, matching the (dataframe, csv path) pair generate_peptide returns
results_table = gr.Dataframe(
    headers=["Binder", "Perplexity"],
    datatype=["str", "number"],
    col_count=(2, "fixed"),
)
csv_download = gr.File(label="Download CSV")
output_components = [results_table, csv_download]

# Assemble the Gradio app and start serving it
interface = gr.Interface(
    fn=generate_peptide,
    inputs=input_components,
    outputs=output_components,
    title="PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling",
)

interface.launch()