Instructions to use nikraf/directionality_probe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nikraf/directionality_probe with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("feature-extraction", model="nikraf/directionality_probe", trust_remote_code=True)
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("nikraf/directionality_probe", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| from typing import List, Tuple, Dict, Union | |
| from .utils import pad_and_concatenate_dimer | |
| def _pad_matrix_embeds(embeds: List[torch.Tensor], max_len: int) -> Tuple[torch.Tensor, torch.Tensor]: | |
| # pad and concatenate, return padded embeds and mask | |
| padded_embeds, attention_masks = [], [] | |
| for embed in embeds: | |
| seq_len = embed.size(0) | |
| padding_size = max_len - seq_len | |
| # Create attention mask (1 for real tokens, 0 for padding) | |
| attention_mask = torch.ones(max_len, dtype=torch.long) | |
| if padding_size > 0: | |
| attention_mask[seq_len:] = 0 | |
| # Pad along the sequence dimension (dim=0) | |
| padding = torch.zeros((padding_size, embed.size(1)), dtype=embed.dtype) | |
| padded_embed = torch.cat((embed, padding), dim=0) | |
| else: | |
| padded_embed = embed | |
| padded_embeds.append(padded_embed) | |
| attention_masks.append(attention_mask) | |
| return torch.stack(padded_embeds), torch.stack(attention_masks) | |
class StringCollator:
    """Collator that hands a batch of strings straight to a tokenizer."""

    def __init__(self, tokenizer, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.tokenizer = tokenizer

    def __call__(self, batch: Tuple[List[str], List[str]]) -> Dict[str, torch.Tensor]:
        """Tokenize ``batch`` and return the tokenizer's output unchanged.

        The tokenizer is asked to pad to the longest item, return PyTorch
        tensors and add special tokens.
        """
        tokenizer_options = {
            'padding': 'longest',
            'return_tensors': 'pt',
            'add_special_tokens': True,
        }
        return self.tokenizer(batch, **tokenizer_options)
class StringLabelsCollator:
    """Collate ``(sequence, label)`` pairs into a tokenized batch with labels.

    Args:
        tokenizer: Callable invoked on the list of sequences; its result must
            support item assignment and, when ``tokenwise`` is set, expose an
            ``'attention_mask'`` entry.
        task_type: ``'multilabel'``, ``'regression'`` and
            ``'sigmoid_regression'`` produce float labels; anything else
            produces long labels.
        tokenwise: When true, each label is a per-token sequence that is
            padded with -100 (or truncated) to the longest attended length.
    """

    # Task types whose labels are cast to float; every other type is cast to long.
    _FLOAT_TASKS = ('multilabel', 'regression', 'sigmoid_regression')

    def __init__(self, tokenizer, task_type='regression', tokenwise=False, **kwargs):
        self.tokenizer = tokenizer
        self.task_type = task_type
        self.tokenwise = tokenwise

    @staticmethod
    def _pad_label(label, max_length: int) -> torch.Tensor:
        """Flatten ``label`` and pad with -100 / truncate to ``max_length``."""
        label = torch.as_tensor(label).flatten()
        padding_size = max_length - len(label)
        if padding_size > 0:
            # -100 is the padding value used for label positions to be ignored.
            padding = torch.full((padding_size,), -100, dtype=label.dtype)
            return torch.cat((label, padding))
        return label[:max_length]

    def __call__(self, batch: List[Tuple[str, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        """Tokenize the sequences and attach a ``'labels'`` tensor.

        Returns:
            The tokenizer output with an added ``'labels'`` entry, cast to
            float or long according to ``task_type``.
        """
        seqs = [ex[0] for ex in batch]
        labels = [ex[1] for ex in batch]

        batch_encoding = self.tokenizer(
            seqs,
            padding='longest',
            truncation=False,
            return_tensors='pt',
            add_special_tokens=True,
        )

        if self.tokenwise:
            # Longest run of attended tokens in the batch sets the label length.
            attention_mask = batch_encoding['attention_mask']
            max_length = max(int(row.sum().item()) for row in attention_mask)
            stacked = torch.stack([self._pad_label(label, max_length) for label in labels])
        else:
            # as_tensor accepts Python scalars and existing tensors alike,
            # without the copy-construct warning torch.tensor(tensor) emits.
            stacked = torch.stack([torch.as_tensor(label) for label in labels])

        if self.task_type in self._FLOAT_TASKS:
            batch_encoding['labels'] = stacked.float()
        else:
            batch_encoding['labels'] = stacked.long()
        return batch_encoding
class EmbedsLabelsCollator:
    """Collate ``(embedding, label)`` pairs into batched tensors.

    Args:
        full: When true, embeddings are variable-length matrices that are
            padded to the longest one in the batch and returned with an
            attention mask. When false, embeddings are stacked directly.
        task_type: ``'multilabel'``, ``'regression'`` and
            ``'sigmoid_regression'`` produce float labels; anything else
            produces long labels.
        tokenwise: Only used when ``full`` is true. Each label is then a
            per-position sequence, padded with -100 (or truncated) to the
            padded embedding length.
    """

    # Task types whose labels are cast to float; every other type is cast to long.
    _FLOAT_TASKS = ('multilabel', 'regression', 'sigmoid_regression')

    def __init__(self, full=False, task_type='regression', tokenwise=False, **kwargs):
        self.full = full
        self.task_type = task_type
        self.tokenwise = tokenwise

    def _cast_labels(self, labels: torch.Tensor) -> torch.Tensor:
        """Cast stacked labels to the dtype implied by ``task_type``."""
        if self.task_type in self._FLOAT_TASKS:
            return labels.float()
        return labels.long()

    @staticmethod
    def _pad_label(label: torch.Tensor, max_length: int) -> torch.Tensor:
        """Flatten ``label`` and pad with -100 / truncate to ``max_length``."""
        label = label.flatten()
        padding_size = max_length - len(label)
        if padding_size > 0:
            # -100 is the padding value used for label positions to be ignored.
            padding = torch.full((padding_size,), -100, dtype=label.dtype)
            return torch.cat((label, padding))
        return label[:max_length]

    def __call__(self, batch: List[Tuple[torch.Tensor, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        """Build the batch dictionary.

        Returns:
            ``{'embeddings', 'labels'}`` when ``full`` is false, otherwise
            ``{'embeddings', 'attention_mask', 'labels'}``.
        """
        # as_tensor lets callers pass plain Python numbers as well as tensors;
        # torch.stack alone rejects non-tensor labels.
        labels = [torch.as_tensor(ex[1]) for ex in batch]

        if not self.full:
            return {
                'embeddings': torch.stack([ex[0] for ex in batch]),
                'labels': self._cast_labels(torch.stack(labels)),
            }

        embeds = [ex[0] for ex in batch]
        # Pad every matrix to the longest sequence in this batch.
        max_length = max(embed.size(0) for embed in embeds)
        embeds, attention_mask = _pad_matrix_embeds(embeds, max_length)

        if self.tokenwise:
            labels = [self._pad_label(label, max_length) for label in labels]

        return {
            'embeddings': embeds,
            'attention_mask': attention_mask,
            'labels': self._cast_labels(torch.stack(labels)),
        }
class PairCollator_input_ids:
    """Collate ``(seq_a, seq_b, label)`` triples by tokenizing each pair jointly."""

    def __init__(self, tokenizer, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.tokenizer = tokenizer

    def __call__(self, batch: List[Tuple[str, str, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        """Return ``input_ids``, ``attention_mask`` and float ``labels``.

        Both sequence columns are passed to a single tokenizer call, padded
        to the longest item and returned as PyTorch tensors.
        """
        # Transpose the list of triples into three parallel columns.
        first_seqs, second_seqs, targets = zip(*batch)
        label_tensor = torch.tensor(targets, dtype=torch.float)

        encoded = self.tokenizer(first_seqs, second_seqs, padding='longest', return_tensors='pt')

        result = {key: encoded[key] for key in ('input_ids', 'attention_mask')}
        result['labels'] = label_tensor
        return result
class PairCollator_ab:
    """Collate ``(seq_a, seq_b, label)`` triples, tokenizing each side separately."""

    def __init__(self, tokenizer, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.tokenizer = tokenizer

    def __call__(self, batch: List[Tuple[str, str, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        """Return per-side ``input_ids``/``attention_mask`` plus float ``labels``.

        Each side is padded to its own longest item, with truncation enabled,
        and returned as PyTorch tensors.
        """
        # Transpose the list of triples into three parallel columns.
        first_seqs, second_seqs, targets = zip(*batch)
        label_tensor = torch.tensor(targets, dtype=torch.float)

        # Both sides share the same tokenizer settings.
        tokenizer_options = {
            'padding': 'longest',
            'truncation': True,
            'return_tensors': 'pt',
        }
        encoded_a = self.tokenizer(first_seqs, **tokenizer_options)
        encoded_b = self.tokenizer(second_seqs, **tokenizer_options)

        result = {}
        result['input_ids_a'] = encoded_a['input_ids']
        result['input_ids_b'] = encoded_b['input_ids']
        result['attention_mask_a'] = encoded_a['attention_mask']
        result['attention_mask_b'] = encoded_b['attention_mask']
        result['labels'] = label_tensor
        return result
class PairEmbedsLabelsCollator:
    """Collate ``(embedding_a, embedding_b, label)`` triples into batched tensors.

    Args:
        full: When true, both embeddings are variable-length matrices; each
            side is padded to its own batch maximum and the two sides are then
            merged by ``pad_and_concatenate_dimer``. When false, the two
            embeddings are stacked and concatenated along the last dimension.
        add_token_ids: Only used when ``full`` is true. Adds a
            ``'token_type_ids'`` tensor marking which positions belong to the
            second input.
    """

    def __init__(self, full=False, add_token_ids=False, **kwargs):
        self.full = full
        self.add_token_ids = add_token_ids

    def __call__(self, batch: List[Tuple[torch.Tensor, torch.Tensor, Union[float, int]]]) -> Dict[str, torch.Tensor]:
        """Build the batch dictionary.

        Returns:
            ``{'embeddings', 'labels'}`` when ``full`` is false; otherwise
            ``{'embeddings', 'attention_mask', 'labels'}``, with
            ``'token_type_ids'`` inserted before ``'labels'`` when
            ``add_token_ids`` is set.
        """
        # as_tensor lets callers pass plain Python numbers as well as tensors;
        # torch.stack alone rejects non-tensor labels.
        labels = torch.stack([torch.as_tensor(ex[2]) for ex in batch])

        if not self.full:
            embeds_a = torch.stack([ex[0] for ex in batch])
            embeds_b = torch.stack([ex[1] for ex in batch])
            return {
                'embeddings': torch.cat([embeds_a, embeds_b], dim=-1),
                'labels': labels,
            }

        embeds_a = [ex[0] for ex in batch]
        embeds_b = [ex[1] for ex in batch]
        # Each side is padded independently to its own longest sequence.
        max_len_a = max(embed.size(0) for embed in embeds_a)
        max_len_b = max(embed.size(0) for embed in embeds_b)
        embeds_a, attention_mask_a = _pad_matrix_embeds(embeds_a, max_len_a)
        embeds_b, attention_mask_b = _pad_matrix_embeds(embeds_b, max_len_b)
        embeds, attention_mask = pad_and_concatenate_dimer(
            embeds_a, embeds_b, attention_mask_a, attention_mask_b
        )

        output = {
            'embeddings': embeds,
            'attention_mask': attention_mask,
        }
        if self.add_token_ids:
            # Type 0 for the first input (and any padding), type 1 for the second.
            # NOTE(review): assumes pad_and_concatenate_dimer places B's real
            # tokens directly after A's real tokens - confirm against utils.
            batch_size = embeds.size(0)
            max_len = embeds.size(1)
            token_type_ids = torch.zeros(batch_size, max_len, dtype=torch.long)
            for i in range(batch_size):
                a_len = int(attention_mask_a[i].sum().item())
                b_len = int(attention_mask_b[i].sum().item())
                token_type_ids[i, a_len:a_len + b_len] = 1
            output['token_type_ids'] = token_type_ids
        output['labels'] = labels
        return output
class OneHotCollator:
    """Collate ``(sequence, label)`` pairs into padded one-hot encodings.

    Args:
        alphabet: Characters that each get their own one-hot channel. An
            ``"X"`` channel is appended and used for any character that is
            not in the alphabet.
    """

    def __init__(self, alphabet="ACDEFGHIKLMNPQRSTVWY"):
        # "X" is appended as the catch-all channel for unknown characters.
        # The alphabet is kept on the instance because __call__ needs its
        # size for the one-hot width.
        self.alphabet = list(alphabet + "X")
        self.mapping = {token: idx for idx, token in enumerate(self.alphabet)}

    def __call__(self, batch):
        """One-hot encode and pad the sequences of ``batch``.

        Args:
            batch: Iterable of ``(sequence, label)`` pairs.

        Returns:
            A dict with ``'embeddings'`` of shape
            ``(batch, longest sequence, len(alphabet) + 1)``, a float
            ``'attention_mask'`` of shape ``(batch, longest sequence)`` with 1
            at real positions and 0 at padding, and the stacked ``'labels'``.
        """
        seqs = [ex[0] for ex in batch]
        # as_tensor accepts Python scalars and existing tensors alike,
        # without the copy-construct warning torch.tensor(tensor) emits.
        labels = torch.stack([torch.as_tensor(ex[1]) for ex in batch])

        max_len = max(len(seq) for seq in seqs)
        vocab_size = len(self.alphabet)
        unknown_idx = self.mapping["X"]

        # Allocate the padded batch once; untouched entries stay zero (padding).
        embeddings = torch.zeros(len(seqs), max_len, vocab_size)
        attention_masks = torch.zeros(len(seqs), max_len)
        for row, seq in enumerate(seqs):
            seq_len = len(seq)
            # Characters outside the alphabet fall back to the X channel.
            channels = torch.tensor(
                [self.mapping.get(token, unknown_idx) for token in seq],
                dtype=torch.long,
            )
            embeddings[row, torch.arange(seq_len), channels] = 1.0
            attention_masks[row, :seq_len] = 1.0

        return {
            'embeddings': embeddings,
            'attention_mask': attention_masks,
            'labels': labels,
        }