import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel
from .configuration_bert_ffnn import BertFFNNConfig


class AttentionPooling(nn.Module):
    """Learned attention pooling over the token dimension."""

    def __init__(self, hidden_size):
        super().__init__()
        # One scalar score per token, computed from its hidden state.
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
        scores = self.attention(hidden_states).squeeze(-1)
        # Exclude padding tokens from the softmax.
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        # Attention-weighted sum of token embeddings -> (batch, hidden)
        return torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)


class BERT_FFNN(PreTrainedModel):
    config_class = BertFFNNConfig
    base_model_prefix = "bert_ffnn"

    def __init__(self, config):
        super().__init__(config)
        # Pretrained encoder backbone, loaded from the checkpoint named in the config.
        self.bert = AutoModel.from_pretrained(config.bert_model_name)
        self.pooling = config.pooling
        self.use_layer_norm = config.use_layer_norm

        if self.pooling == "attention":
            self.attention_pool = AttentionPooling(self.bert.config.hidden_size)

        # Optionally freeze the whole backbone, or only its first N encoder layers
        # (the layer-wise option assumes a BERT-style `encoder.layer` module list).
        if config.freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False
        elif config.freeze_layers > 0:
            for layer in self.bert.encoder.layer[:config.freeze_layers]:
                for p in layer.parameters():
                    p.requires_grad = False

        # Feed-forward head: [Linear -> ReLU -> (LayerNorm) -> Dropout] per hidden dim,
        # followed by a final projection to the output dimension.
        layers = []
        in_dim = self.bert.config.hidden_size
        for h_dim in config.hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            if config.use_layer_norm:
                layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.Dropout(config.dropout))
            in_dim = h_dim

        layers.append(nn.Linear(in_dim, config.output_dim))
        self.classifier = nn.Sequential(*layers)

        # Initialize the newly added (non-pretrained) weights.
        self.post_init()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        if self.pooling == "mean":
            # Mean over non-padding tokens only.
            mask = attention_mask.unsqueeze(-1).float()
            sum_emb = (outputs.last_hidden_state * mask).sum(1)
            features = sum_emb / mask.sum(1).clamp(min=1e-9)
        elif self.pooling == "max":
            # Max over non-padding tokens (padding positions are set to -inf first).
            mask = attention_mask.unsqueeze(-1).float()
            masked_emb = outputs.last_hidden_state.masked_fill(mask == 0, float('-inf'))
            features, _ = masked_emb.max(dim=1)
        elif self.pooling == "attention":
            # Learned attention-weighted sum over tokens.
            features = self.attention_pool(outputs.last_hidden_state, attention_mask)
        else:  # CLS pooling
            # Use the backbone's pooler output if it has one, otherwise the [CLS] token.
            features = (
                outputs.pooler_output
                if getattr(outputs, "pooler_output", None) is not None
                else outputs.last_hidden_state[:, 0]
            )

        logits = self.classifier(features)
        return logits
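

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the published model code. It assumes
    # BertFFNNConfig accepts the fields read above as keyword arguments and that
    # "bert-base-uncased" is reachable as the backbone checkpoint.
    config = BertFFNNConfig(
        bert_model_name="bert-base-uncased",
        pooling="mean",
        hidden_dims=[256, 64],
        dropout=0.1,
        use_layer_norm=True,
        freeze_bert=False,
        freeze_layers=0,
        output_dim=2,
    )
    model = BERT_FFNN(config)
    model.eval()

    # Dummy batch: 2 sequences of length 16, with padding on the tail of the second.
    input_ids = torch.randint(0, model.bert.config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    attention_mask[1, 10:] = 0

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
    print(logits.shape)  # expected: torch.Size([2, 2])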