"""SONAR-Text-to-Text modeling code (modeling.py), built on top of M2M100."""
from typing import Optional, Union
import torch
import torch.nn as nn
from transformers import M2M100Config
from transformers.models.m2m_100.modeling_m2m_100 import (
    M2M100Encoder,
    M2M100ForConditionalGeneration,
    M2M100Model,
    shift_tokens_right,
    logger,
)
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput, Seq2SeqModelOutput
from transformers.utils import auto_docstring
from torch.nn import CrossEntropyLoss
class Pooling(nn.Module):
"""
Pooling layer for sequence representations.
Supports multiple pooling strategies: mean, cls, last, max, and none.
"""
def __init__(self, pooling_type: str = "mean"):
super().__init__()
valid_types = {"mean", "cls", "last", "max", "none"}
if pooling_type not in valid_types:
raise ValueError(f"pooling_type must be one of {valid_types}, got {pooling_type}")
self.pooling_type = pooling_type
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Apply pooling to hidden states.
Args:
hidden_states: Tensor of shape (batch_size, seq_len, hidden_size)
attention_mask: Tensor of shape (batch_size, seq_len), values in {0, 1}
Returns:
            Pooled tensor of shape (batch_size, 1, hidden_size), or the unchanged
            (batch_size, seq_len, hidden_size) tensor when pooling_type is "none"
"""
if self.pooling_type == "none":
return hidden_states
if self.pooling_type == "cls":
return hidden_states[:, 0, :].unsqueeze(1)
elif self.pooling_type == "last":
return hidden_states[:, -1, :].unsqueeze(1)
elif self.pooling_type in ["mean", "max"]:
if attention_mask is None:
raise ValueError(f"attention_mask is required for {self.pooling_type} pooling")
# Expand attention mask to match hidden_states dimensions
mask = attention_mask.unsqueeze(-1)
if self.pooling_type == "mean":
# Apply mask and compute mean over valid tokens
masked_hidden = hidden_states * mask
sum_hidden = masked_hidden.sum(dim=1, keepdim=True)
sum_mask = mask.sum(dim=1, keepdim=True)
return sum_hidden / sum_mask
elif self.pooling_type == "max":
# Apply mask (set masked positions to large negative value)
masked_hidden = hidden_states.masked_fill(mask == 0, float('-inf'))
return masked_hidden.max(dim=1, keepdim=True)[0]
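
# Illustrative sketch of the Pooling layer (hedged example in comments only,
# so nothing runs on import; shapes are arbitrary):
#
#   pool = Pooling("mean")
#   hidden = torch.randn(2, 5, 8)                 # (batch, seq_len, hidden)
#   mask = torch.tensor([[1, 1, 1, 0, 0],
#                        [1, 1, 1, 1, 1]])        # 0 marks padding
#   pooled = pool(hidden, mask)                   # (batch, 1, hidden)
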
class SONARTextEncoder(M2M100Encoder):
"""
Transformer encoder with pooling capabilities.
Inherits from M2M100Encoder and adds configurable pooling functionality.
Args:
config: M2M100Config with optional pooling_type attribute
embed_tokens: Optional embedding layer
"""
def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = None):
super().__init__(config, embed_tokens)
# Initialize pooling layer
pooling_type = getattr(config, 'pooling_type', 'mean')
self.pooling = Pooling(pooling_type)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
"""
Forward pass with optional pooling.
Args:
input_ids: Input token ids
attention_mask: Attention mask for padding tokens
head_mask: Mask for attention heads
inputs_embeds: Pre-computed embeddings
output_attentions: Whether to return attention weights
output_hidden_states: Whether to return all hidden states
return_dict: Whether to return ModelOutput object
Returns:
Model output with pooled representations
"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Run the standard M2M100 encoder
        encoder_output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Pool the token-level hidden states into a sentence-level representation
        hidden_states = encoder_output.last_hidden_state if return_dict else encoder_output[0]
        pooled_output = self.pooling(hidden_states, attention_mask)
        if return_dict:
            encoder_output.last_hidden_state = pooled_output
        else:
            encoder_output = (pooled_output,) + encoder_output[1:]
        return encoder_output
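
# Sketch of the pooled encoder contract (hedged; assumes pooling_type != "none"
# and placeholder tensors ids/mask):
#
#   encoder = SONARTextEncoder(config)
#   out = encoder(input_ids=ids, attention_mask=mask, return_dict=True)
#   out.last_hidden_state.shape                   # (batch, 1, hidden)
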
class SONARModel(M2M100Model):
"""SONAR model based on M2M100."""
def __init__(self, config: M2M100Config):
super().__init__(config)
self.encoder = SONARTextEncoder(config, self.shared)
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
M2M100 uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
`past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
        if attention_mask is not None:
            if encoder_outputs[0].dim() == 3 and encoder_outputs[0].size(1) != 1:
                # Encoder output still has the full sequence length, i.e.
                # pooling was disabled ("none"), so keep the original mask.
                logger.warning_once("Encoder output is not pooled")
                encoder_attention_mask = attention_mask
            else:
                # The pooled output is a single vector per sentence, so the
                # decoder only needs a length-1 cross-attention mask.
                encoder_attention_mask = attention_mask[:, :1]
        else:
            encoder_attention_mask = None
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=encoder_attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
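
# With a pooled encoder, the decoder cross-attends to a single sentence vector,
# which is why SONARModel truncates the cross-attention mask to length 1 above.
# Hedged sketch with placeholder tensors src_ids/src_mask/tgt_ids:
#
#   model = SONARModel(config)
#   out = model(input_ids=src_ids, attention_mask=src_mask,
#               decoder_input_ids=tgt_ids, return_dict=True)
#   out.encoder_last_hidden_state.shape           # (batch, 1, hidden)
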
class SONARForText2Text(M2M100ForConditionalGeneration):
"""SONAR model for conditional generation tasks."""
# _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: M2M100Config):
super().__init__(config)
self.model = SONARModel(config)
self.cross_entropy_loss = CrossEntropyLoss(
label_smoothing=0.1,
ignore_index=-100
)
self.mse_loss = nn.MSELoss()
self.mse_ratio = getattr(config, 'mse_ratio', 0.2)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
target_ids: Optional[torch.LongTensor] = None,
target_attention_mask: Optional[torch.Tensor] = None,
mse_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
M2M100 uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
`past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Example (translation):
        ```python
        >>> from transformers import AutoTokenizer
        >>> model = SONARForText2Text.from_m2m100_pretrained("facebook/m2m100_418M")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
        >>> text_to_translate = "Life is like a box of chocolates"
        >>> model_inputs = tokenizer(text_to_translate, return_tensors="pt")
        >>> # translate to French
        >>> gen_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.get_lang_id("fr"))
        >>> print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True))
        ```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
lm_logits = self.lm_head(outputs[0])
        masked_lm_loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            # CrossEntropyLoss already reduces to a scalar mean
            masked_lm_loss = self.cross_entropy_loss(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
        if mse_mask is not None and target_ids is None:
            # Inputs are assumed to arrive as consecutive pairs; pull the two
            # pooled sentence embeddings of each pair together with MSE.
            mse_mask = mse_mask.view(-1, 1, 1).to(outputs.encoder_last_hidden_state.device)
            sentence_embeddings = outputs.encoder_last_hidden_state.squeeze(1)
            batch_size = sentence_embeddings.size(0)
            # Reshape to pair structure: (batch // 2, 2, hidden)
            paired = sentence_embeddings[: batch_size // 2 * 2].view(batch_size // 2, 2, *sentence_embeddings.shape[1:]) * mse_mask
            mse_loss = self.mse_loss(paired[:, 0], paired[:, 1])
            if masked_lm_loss is None:
                masked_lm_loss = self.mse_ratio * mse_loss
            else:
                masked_lm_loss = masked_lm_loss + self.mse_ratio * mse_loss
        if target_ids is not None and labels is not None:
            # Align source and target sentence embeddings with an MSE penalty;
            # return_dict=True because .last_hidden_state is accessed below
            target_ids = target_ids.to(lm_logits.device)
            target_encoder_outputs = self.model.encoder(
                input_ids=target_ids,
                attention_mask=target_attention_mask,
                return_dict=True,
            )
            mse_loss = self.mse_loss(outputs.encoder_last_hidden_state, target_encoder_outputs.last_hidden_state)
            masked_lm_loss += self.mse_ratio * mse_loss
# print(f"Masked LM Loss: {masked_lm_loss.item() if masked_lm_loss is not None else 'N/A'}")
# print(f"MSE Loss: {mse_loss.item() if mse_loss is not None else 'N/A'}")
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return Seq2SeqLMOutput(
loss=masked_lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
    @classmethod
    def from_m2m100_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Build a SONAR model from a pretrained M2M100 checkpoint."""
        model = M2M100ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        generation_config = model.generation_config
        generation_config.early_stopping = True
        generation_config.num_beams = 5
        generation_config.max_length = 500
        config = model.config
        config.pooling_type = getattr(config, 'pooling_type', 'mean')
        sonar_model = cls(config)
        # strict=False: the pooling layer has no counterpart in the M2M100 checkpoint
        sonar_model.load_state_dict(model.state_dict(), strict=False)
        sonar_model.generation_config = generation_config
        return sonar_model
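
# Minimal sentence-embedding sketch (checkpoint name is illustrative; see the
# forward() docstring above for a translation example):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
#   model = SONARForText2Text.from_m2m100_pretrained("facebook/m2m100_418M")
#   batch = tokenizer(["Hello world", "Bonjour le monde"], return_tensors="pt", padding=True)
#   with torch.no_grad():
#       embeddings = model.model.encoder(**batch, return_dict=True).last_hidden_state
#   embeddings.shape                              # (2, 1, d_model)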