# Custom Hugging Face pipeline for ChatNT: the model answers English questions
# about DNA sequences, so two tokenizers are kept, one per modality.
import numpy as np
import torch
from transformers import AutoTokenizer, Pipeline


class TextGenerationPipeline(Pipeline):
    def __init__(self, model, **kwargs):
        super().__init__(model=model, **kwargs)

        model_name = "InstaDeepAI/ChatNT"
        self.english_tokenizer = AutoTokenizer.from_pretrained(
            model_name, subfolder="english_tokenizer"
        )
        self.bio_tokenizer = AutoTokenizer.from_pretrained(
            model_name, subfolder="bio_tokenizer"
        )

    def _sanitize_parameters(self, **kwargs: dict) -> tuple[dict, dict, dict]:
        # Route user-supplied kwargs to the preprocess / forward stages.
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}

        if "max_num_tokens_to_decode" in kwargs:
            forward_kwargs["max_num_tokens_to_decode"] = kwargs[
                "max_num_tokens_to_decode"
            ]
        if "english_tokens_max_length" in kwargs:
            preprocess_kwargs["english_tokens_max_length"] = kwargs[
                "english_tokens_max_length"
            ]
        if "bio_tokens_max_length" in kwargs:
            preprocess_kwargs["bio_tokens_max_length"] = kwargs["bio_tokens_max_length"]

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
        self,
        inputs: dict,
        english_tokens_max_length: int = 512,
        bio_tokens_max_length: int = 512,
    ) -> dict:
        english_sequence = inputs["english_sequence"]
        dna_sequences = inputs["dna_sequences"]

        # Wrap the user question in the conversation template used by ChatNT.
        context = (
            "A chat between a curious user and an artificial intelligence "
            "assistant that can handle bio sequences. The assistant gives "
            "helpful, detailed, and polite answers to the user's questions. "
            "USER: "
        )
        # Avoid a double space before the "ASSISTANT:" tag.
        space = " "
        if english_sequence[-1] == " ":
            space = ""
        english_sequence = context + english_sequence + space + "ASSISTANT:"

        english_tokens = self.english_tokenizer(
            english_sequence,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=english_tokens_max_length,
        ).input_ids
        if len(dna_sequences) == 0:
            bio_tokens = None
        else:
            # unsqueeze adds a batch dimension: (1, num_sequences, seq_len).
            bio_tokens = self.bio_tokenizer(
                dna_sequences,
                return_tensors="pt",
                padding="max_length",
                max_length=bio_tokens_max_length,
                truncation=True,
            ).input_ids.unsqueeze(0)

        return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}

    def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
        english_tokens = model_inputs["english_tokens"].clone()
        bio_tokens = model_inputs["bio_tokens"]
        if bio_tokens is not None:
            bio_tokens = bio_tokens.clone()
        projected_bio_embeddings = None

        actual_num_steps = 0
        with torch.no_grad():
            for _ in range(max_num_tokens_to_decode):
                # Stop if the buffer is full: no pad slot left to write into.
                if (
                    self.english_tokenizer.pad_token_id
                    not in english_tokens[0].cpu().numpy()
                ):
                    break

                outs = self.model(
                    multi_omics_tokens_ids=(english_tokens, bio_tokens),
                    projection_english_tokens_ids=english_tokens,
                    projected_bio_embeddings=projected_bio_embeddings,
                )
                # Reuse the projected DNA embeddings from the first step on
                # later steps instead of recomputing them.
                projected_bio_embeddings = outs["projected_bio_embeddings"]
                logits = outs["logits"].detach().cpu().numpy()

                # Greedy decoding: the next token is the argmax of the logits
                # at the last non-pad position.
                first_idx_pad_token = np.where(
                    english_tokens[0].cpu() == self.english_tokenizer.pad_token_id
                )[0][0]
                predicted_token = np.argmax(logits[0, first_idx_pad_token - 1])

                # Stop at end-of-sequence; otherwise write the new token into
                # the first free (pad) slot.
                if predicted_token == self.english_tokenizer.eos_token_id:
                    break
                else:
                    english_tokens[0, first_idx_pad_token] = predicted_token
                    actual_num_steps += 1

        # The answer starts at the first pad position of the original prompt.
        idx_begin_generation = np.where(
            model_inputs["english_tokens"][0].cpu()
            == self.english_tokenizer.pad_token_id
        )[0][0]

        generated_tokens = english_tokens[
            0, idx_begin_generation : idx_begin_generation + actual_num_steps
        ]

        return {
            "generated_tokens": generated_tokens,
        }

    def postprocess(self, model_outputs: dict) -> str:
        generated_tokens = model_outputs["generated_tokens"]
        generated_sequence: str = self.english_tokenizer.decode(generated_tokens)
        return generated_sequence
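
# The usage sketch below is illustrative, not part of the original file. It
# assumes the hosted "InstaDeepAI/ChatNT" checkpoint registers this class as a
# custom pipeline via trust_remote_code, and that "<DNA>" marks where a DNA
# sequence is spliced into the English prompt; the sequence shown is a
# hypothetical placeholder. The input dict keys match what preprocess() reads.
if __name__ == "__main__":
    from transformers import pipeline

    chat = pipeline(model="InstaDeepAI/ChatNT", trust_remote_code=True)
    answer = chat(
        inputs={
            "english_sequence": "Does this sequence <DNA> contain a donor splice site?",
            "dna_sequences": ["ATGGCAGTTACCGTTAGC"],  # hypothetical DNA input
        },
        max_num_tokens_to_decode=50,
    )
    print(answer)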