import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel

from .configuration_bert_ffnn import BertFFNNConfig


class AttentionPooling(nn.Module):
    """Learned attention pooling over token embeddings."""

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        # One scalar score per token; padding positions are pushed to -1e9
        # so they receive ~0 weight after the softmax.
        scores = self.attention(hidden_states).squeeze(-1)
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        return torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)


class BERT_FFNN(PreTrainedModel):
    """BERT encoder followed by a feed-forward classification head."""

    config_class = BertFFNNConfig
    base_model_prefix = "bert_ffnn"

    def __init__(self, config):
        super().__init__(config)
        self.bert = AutoModel.from_pretrained(config.bert_model_name)
        self.pooling = config.pooling
        self.use_layer_norm = config.use_layer_norm

        if self.pooling == "attention":
            self.attention_pool = AttentionPooling(self.bert.config.hidden_size)

        # Optionally freeze the whole encoder, or only its first N layers.
        if config.freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False
        elif config.freeze_layers > 0:
            for layer in self.bert.encoder.layer[:config.freeze_layers]:
                for p in layer.parameters():
                    p.requires_grad = False

        # Feed-forward head: Linear -> ReLU -> (LayerNorm) -> Dropout per hidden dim,
        # followed by a final projection to the output dimension.
        layers = []
        in_dim = self.bert.config.hidden_size
        for h_dim in config.hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            if config.use_layer_norm:
                layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.Dropout(config.dropout))
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, config.output_dim))
        self.classifier = nn.Sequential(*layers)

        self.post_init()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        if self.pooling == "mean":
            # Mean over non-padding tokens.
            mask = attention_mask.unsqueeze(-1).float()
            sum_emb = (outputs.last_hidden_state * mask).sum(1)
            features = sum_emb / mask.sum(1).clamp(min=1e-9)
        elif self.pooling == "max":
            # Max over non-padding tokens; padding is set to -inf so it never wins.
            mask = attention_mask.unsqueeze(-1).float()
            masked_emb = outputs.last_hidden_state.masked_fill(mask == 0, float("-inf"))
            features, _ = masked_emb.max(dim=1)
        elif self.pooling == "attention":
            features = self.attention_pool(outputs.last_hidden_state, attention_mask)
        else:
            # CLS pooling: prefer the pooler output when the backbone provides one.
            features = (
                outputs.pooler_output
                if getattr(outputs, "pooler_output", None) is not None
                else outputs.last_hidden_state[:, 0]
            )

        logits = self.classifier(features)
        return logits
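

# Minimal usage sketch (illustrative only, kept as comments because this module
# uses relative imports): it assumes BertFFNNConfig in configuration_bert_ffnn
# accepts the fields referenced above (bert_model_name, pooling, hidden_dims,
# output_dim, dropout, use_layer_norm, freeze_bert, freeze_layers) as keyword
# arguments; the exact signature may differ in your config class.
#
#     from transformers import AutoTokenizer
#
#     config = BertFFNNConfig(
#         bert_model_name="bert-base-uncased",
#         pooling="mean",
#         hidden_dims=[256],
#         output_dim=2,
#         dropout=0.1,
#         use_layer_norm=True,
#         freeze_bert=False,
#         freeze_layers=0,
#     )
#     model = BERT_FFNN(config)
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
#     logits = model(batch["input_ids"], batch["attention_mask"])  # shape: (1, output_dim)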