import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig


class CodeEmbedder(nn.Module):
    """
    A wrapper around a Transformer model (default: CodeBERT) to produce
    dense vector embeddings for code snippets using mean pooling.
    """
    def __init__(self, model_name_or_path="microsoft/codebert-base", trust_remote_code=False):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
        self.encoder = AutoModel.from_pretrained(model_name_or_path, config=self.config, trust_remote_code=trust_remote_code)

    def mean_pooling(self, token_embeddings, attention_mask):
        """
        Average the token embeddings, ignoring padding tokens.
        """
        # Expand the attention mask to match the embedding dimensions.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

        # Sum the embeddings over the sequence dimension, zeroing out padding positions.
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

        # Count the non-padding tokens, clamped to avoid division by zero.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        return sum_embeddings / sum_mask

    def forward(self, input_ids, attention_mask):
        # Encode the batch and take the per-token hidden states.
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state

        # Pool the token embeddings into one vector per input sequence.
        embeddings = self.mean_pooling(last_hidden_state, attention_mask)

        # L2-normalize so that cosine similarity reduces to a dot product.
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings
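

# ---------------------------------------------------------------------------
# Example usage (a minimal sketch, not part of the model itself). It assumes
# the tokenizer matching the checkpoint (microsoft/codebert-base by default)
# is available from the Hugging Face hub, and shows that, because forward()
# L2-normalizes, cosine similarity is just a dot product between embeddings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    model = CodeEmbedder()
    model.eval()

    snippets = [
        "def add(a, b): return a + b",
        "print('hello world')",
    ]
    batch = tokenizer(snippets, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        embeddings = model(batch["input_ids"], batch["attention_mask"])

    # Pairwise cosine similarities between the embedded snippets.
    similarity = embeddings @ embeddings.T
    print(embeddings.shape, similarity)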