"""
Dual-head multi-label PyTorch training script for mmBERT-base.

Two classification heads, onderwerp (topic) and beleving (experience), with
dynamic label counts, trained with a combined Soft-F1 + BCE loss whose balance
is set by a configurable weight α.

Features: optional learnable per-class thresholds, linear warmup + cosine LR
decay, and gradient clipping.

mmBERT: modern multilingual encoder (1800+ languages, 2x faster than XLM-R).
"""

import json
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer

from rd_dataset_loader import load_rd_wim_dataset


def prob_to_logit(p: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Convert probabilities to logits (inverse sigmoid). Numerically stable."""
    p = torch.clamp(p, eps, 1 - eps)
    return torch.log(p / (1 - p))


def logit_to_prob(l: torch.Tensor) -> torch.Tensor:
    """Convert logits to probabilities using sigmoid."""
    return torch.sigmoid(l)
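

# Hedged sanity check (a minimal sketch, not called during training): the two
# helpers above should be inverses of each other up to the clamp epsilon.
# `_check_prob_logit_roundtrip` is a name introduced here for illustration.
def _check_prob_logit_roundtrip():
    p = torch.tensor([0.05, 0.5, 0.95])
    assert torch.allclose(logit_to_prob(prob_to_logit(p)), p, atol=1e-5)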


def get_device():
    """Pick the best available device: MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using MPS (Apple Silicon) for acceleration")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using CUDA GPU")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    return device


def set_seed(seed):
    """Set random seeds for reproducibility across torch, numpy, and Python random."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


class mmBERTDualHead(nn.Module):
    """
    mmBERT with two classification heads for multi-task learning.

    A shared encoder feeds separate heads for onderwerp and beleving.
    Optionally includes learnable per-class decision thresholds for each head,
    stored in logit space.
    """

    def __init__(self, model_name, num_onderwerp, num_beleving, dropout,
                 initial_threshold, use_thresholds: bool = True):
        super().__init__()
        self.use_thresholds = use_thresholds

        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size

        self.onderwerp_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_size, num_onderwerp)
        )

        self.beleving_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_size, num_beleving)
        )

        # Thresholds are parameterized in logit space so they can be optimized
        # without constraints; logit_to_prob recovers the probability cutoff.
        self.onderwerp_tau_logit = None
        self.beleving_tau_logit = None
        if self.use_thresholds:
            # .item() so torch.full receives a plain Python scalar.
            init_logit = prob_to_logit(torch.tensor(initial_threshold)).item()
            self.onderwerp_tau_logit = nn.Parameter(torch.full((num_onderwerp,), init_logit))
            self.beleving_tau_logit = nn.Parameter(torch.full((num_beleving,), init_logit))

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Pool with the [CLS] token representation (first position).
        pooled_output = outputs.last_hidden_state[:, 0, :]

        onderwerp_logits = self.onderwerp_head(pooled_output)
        beleving_logits = self.beleving_head(pooled_output)

        return onderwerp_logits, beleving_logits
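

# Hedged shape check (a minimal sketch with made-up sizes, not called during
# training; requires downloading the encoder): a batch of 2 sequences of
# length 8 should yield one logit row per sample and one column per class.
def _check_dual_head_shapes(model_name="jhu-clsp/mmBERT-base"):
    model = mmBERTDualHead(model_name, num_onderwerp=4, num_beleving=3,
                           dropout=0.1, initial_threshold=0.5)
    ids = torch.randint(0, 100, (2, 8))
    mask = torch.ones_like(ids)
    on_logits, be_logits = model(ids, mask)
    assert on_logits.shape == (2, 4) and be_logits.shape == (2, 3)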


class DutchDualLabelDataset(Dataset):
    """Dataset for dual-label classification (onderwerp + beleving)."""

    def __init__(self, texts, onderwerp_labels, beleving_labels, tokenizer, max_length):
        self.texts = texts
        self.onderwerp_labels = onderwerp_labels
        self.beleving_labels = beleving_labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            # squeeze(0) drops only the batch dimension added by return_tensors='pt'.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'onderwerp_labels': torch.tensor(self.onderwerp_labels[idx], dtype=torch.float),
            'beleving_labels': torch.tensor(self.beleving_labels[idx], dtype=torch.float)
        }
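

# Hedged usage sketch (illustration only, with dummy data; not executed by the
# training script): build a two-sample dataset and batch it with a DataLoader.
def _demo_dataset_usage(tokenizer):
    texts = ["voorbeeldtekst een", "voorbeeldtekst twee"]
    onderwerp = np.array([[1, 0, 0], [0, 1, 1]], dtype=np.float32)
    beleving = np.array([[0, 1], [1, 0]], dtype=np.float32)
    ds = DutchDualLabelDataset(texts, onderwerp, beleving, tokenizer, max_length=16)
    batch = next(iter(DataLoader(ds, batch_size=2)))
    assert batch['input_ids'].shape == (2, 16)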


def calculate_soft_f1(logits, labels, logit_threshold=None, temperature=1.0):
    """
    Calculate a differentiable F1 score using a sigmoid approximation.

    If logit_threshold is None: y_soft = sigmoid(logits * T)
    Else:                       y_soft = sigmoid((logits - logit_threshold) * T)

    Rationale:
    - With thresholds ON, Soft-F1 learns per-class decision boundaries in logit space.
    - With thresholds OFF, we follow POLA: a single, obvious source (head logits).

    Args:
        logits: Model predictions (before sigmoid)
        labels: True labels (multi-hot encoded)
        logit_threshold: Optional decision threshold in LOGIT space (None = no shift)
        temperature: Sharpness of the sigmoid approximation

    Returns:
        soft_f1: Differentiable F1 score, averaged over the batch
    """
    if logit_threshold is None:
        shifted = logits * temperature
    else:
        shifted = (logits - logit_threshold) * temperature

    y_pred_soft = torch.sigmoid(shifted)

    # Soft confusion counts per sample.
    TP = (y_pred_soft * labels).sum(dim=-1)
    FP = (y_pred_soft * (1 - labels)).sum(dim=-1)
    FN = ((1 - y_pred_soft) * labels).sum(dim=-1)

    eps = 1e-8
    precision = TP / (TP + FP + eps)
    recall = TP / (TP + FN + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)

    return f1.mean()
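

# Hedged illustration (a minimal sketch with made-up logits, not called during
# training): higher temperature sharpens the sigmoid, pushing the soft F1
# toward the hard F1 computed from thresholded predictions.
def _demo_soft_f1_temperature():
    logits = torch.tensor([[2.0, -1.5, 0.3]])
    labels = torch.tensor([[1.0, 0.0, 1.0]])
    soft = calculate_soft_f1(logits, labels, temperature=1.0)
    sharp = calculate_soft_f1(logits, labels, temperature=10.0)
    # With T=10 the soft score should sit close to the hard F1 (1.0 here).
    assert sharp > soft and abs(sharp.item() - 1.0) < 0.05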


def evaluate(model, val_texts, val_onderwerp, val_beleving, tokenizer, device,
             onderwerp_names, beleving_names, num_samples, max_length):
    """
    Evaluate the model on the validation set and return metrics.

    Args:
        model: The trained model
        val_texts: List of validation texts
        val_onderwerp: Validation onderwerp labels
        val_beleving: Validation beleving labels
        tokenizer: Tokenizer for encoding text
        device: Device to run evaluation on
        onderwerp_names: List of onderwerp label names
        beleving_names: List of beleving label names
        num_samples: Number of samples to evaluate (None = all)
        max_length: Max sequence length

    Returns:
        dict: Dictionary containing all evaluation metrics
    """
    model.eval()

    if num_samples is None:
        num_samples = len(val_texts)
    else:
        num_samples = min(num_samples, len(val_texts))

    # Per-class accuracy counts.
    onderwerp_correct = np.zeros(len(onderwerp_names))
    onderwerp_total = np.zeros(len(onderwerp_names))
    beleving_correct = np.zeros(len(beleving_names))
    beleving_total = np.zeros(len(beleving_names))

    # Global counts for micro-averaged precision/recall/F1.
    onderwerp_tp = 0
    onderwerp_fp = 0
    onderwerp_fn = 0
    beleving_tp = 0
    beleving_fp = 0
    beleving_fn = 0

    with torch.inference_mode():
        for i in range(num_samples):
            encoding = tokenizer(
                val_texts[i],
                truncation=True,
                padding='max_length',
                max_length=max_length,
                return_tensors='pt'
            )

            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)

            onderwerp_logits, beleving_logits = model(input_ids, attention_mask)

            onderwerp_probs = torch.sigmoid(onderwerp_logits)
            beleving_probs = torch.sigmoid(beleving_logits)

            if model.use_thresholds:
                tau_on = logit_to_prob(model.onderwerp_tau_logit)
                tau_be = logit_to_prob(model.beleving_tau_logit)
            else:
                # Fixed cutoff when thresholds are disabled.
                tau_on = torch.full_like(onderwerp_probs[0], 0.5)
                tau_be = torch.full_like(beleving_probs[0], 0.5)

            onderwerp_pred = (onderwerp_probs > tau_on).squeeze().cpu().numpy()
            beleving_pred = (beleving_probs > tau_be).squeeze().cpu().numpy()

            onderwerp_true = val_onderwerp[i]
            beleving_true = val_beleving[i]

            onderwerp_tp += ((onderwerp_pred == 1) & (onderwerp_true == 1)).sum()
            onderwerp_fp += ((onderwerp_pred == 1) & (onderwerp_true == 0)).sum()
            onderwerp_fn += ((onderwerp_pred == 0) & (onderwerp_true == 1)).sum()

            beleving_tp += ((beleving_pred == 1) & (beleving_true == 1)).sum()
            beleving_fp += ((beleving_pred == 1) & (beleving_true == 0)).sum()
            beleving_fn += ((beleving_pred == 0) & (beleving_true == 1)).sum()

            for j in range(len(onderwerp_names)):
                if onderwerp_pred[j] == onderwerp_true[j]:
                    onderwerp_correct[j] += 1
                onderwerp_total[j] += 1

            for j in range(len(beleving_names)):
                if beleving_pred[j] == beleving_true[j]:
                    beleving_correct[j] += 1
                beleving_total[j] += 1

    epsilon = 1e-8
    onderwerp_precision = onderwerp_tp / (onderwerp_tp + onderwerp_fp + epsilon)
    onderwerp_recall = onderwerp_tp / (onderwerp_tp + onderwerp_fn + epsilon)
    onderwerp_f1_score = 2 * onderwerp_precision * onderwerp_recall / (onderwerp_precision + onderwerp_recall + epsilon)

    beleving_precision = beleving_tp / (beleving_tp + beleving_fp + epsilon)
    beleving_recall = beleving_tp / (beleving_tp + beleving_fn + epsilon)
    beleving_f1_score = 2 * beleving_precision * beleving_recall / (beleving_precision + beleving_recall + epsilon)

    onderwerp_acc = onderwerp_correct.sum() / onderwerp_total.sum()
    beleving_acc = beleving_correct.sum() / beleving_total.sum()

    if model.use_thresholds:
        onderwerp_thresh_mean = logit_to_prob(model.onderwerp_tau_logit).mean().item()
        onderwerp_thresh_min = logit_to_prob(model.onderwerp_tau_logit).min().item()
        onderwerp_thresh_max = logit_to_prob(model.onderwerp_tau_logit).max().item()
        onderwerp_thresh_std = logit_to_prob(model.onderwerp_tau_logit).std().item()
        beleving_thresh_mean = logit_to_prob(model.beleving_tau_logit).mean().item()
        beleving_thresh_min = logit_to_prob(model.beleving_tau_logit).min().item()
        beleving_thresh_max = logit_to_prob(model.beleving_tau_logit).max().item()
        beleving_thresh_std = logit_to_prob(model.beleving_tau_logit).std().item()
    else:
        onderwerp_thresh_mean = onderwerp_thresh_min = onderwerp_thresh_max = onderwerp_thresh_std = 0.5
        beleving_thresh_mean = beleving_thresh_min = beleving_thresh_max = beleving_thresh_std = 0.5

    return {
        'onderwerp_acc': onderwerp_acc,
        'onderwerp_precision': onderwerp_precision,
        'onderwerp_recall': onderwerp_recall,
        'onderwerp_f1': onderwerp_f1_score,
        'beleving_acc': beleving_acc,
        'beleving_precision': beleving_precision,
        'beleving_recall': beleving_recall,
        'beleving_f1': beleving_f1_score,
        'combined_acc': (onderwerp_acc + beleving_acc) / 2,
        'combined_f1': (onderwerp_f1_score + beleving_f1_score) / 2,
        'onderwerp_thresh_mean': onderwerp_thresh_mean,
        'onderwerp_thresh_min': onderwerp_thresh_min,
        'onderwerp_thresh_max': onderwerp_thresh_max,
        'onderwerp_thresh_std': onderwerp_thresh_std,
        'beleving_thresh_mean': beleving_thresh_mean,
        'beleving_thresh_min': beleving_thresh_min,
        'beleving_thresh_max': beleving_thresh_max,
        'beleving_thresh_std': beleving_thresh_std,
        'num_samples_evaluated': num_samples
    }
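

# Hedged cross-check (a sketch under the assumption that scikit-learn is
# available; it is not a dependency of this script): the micro-averaged F1
# computed from global TP/FP/FN counts, as in evaluate() above, should match
# sklearn's f1_score(..., average='micro') on the same predictions.
def _check_micro_f1_against_sklearn():
    from sklearn.metrics import f1_score  # assumed extra dependency
    y_true = np.array([[1, 0, 1], [0, 1, 0]])
    y_pred = np.array([[1, 0, 0], [0, 1, 1]])
    tp = ((y_pred == 1) & (y_true == 1)).sum()
    fp = ((y_pred == 1) & (y_true == 0)).sum()
    fn = ((y_pred == 0) & (y_true == 1)).sum()
    manual = 2 * tp / (2 * tp + fp + fn)
    assert abs(manual - f1_score(y_true, y_pred, average='micro')) < 1e-9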


def grad_l2_norm(params):
    """
    Calculate the L2 norm of gradients safely (avoids Python int → Tensor addition).

    Args:
        params: Iterator of parameters (e.g., model.parameters())

    Returns:
        float: L2 norm of all gradients, or 0.0 if no gradients exist
    """
    sq_sum = None
    for p in params:
        if p.grad is None:
            continue
        g = p.grad
        val = g.pow(2).sum()
        sq_sum = val if sq_sum is None else (sq_sum + val)
    if sq_sum is None:
        return 0.0
    return sq_sum.sqrt().item()
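

# Design note (hedged): torch.nn.utils.clip_grad_norm_ already returns the
# total pre-clip norm, so the global norm could be taken from its return value
# instead of a separate grad_l2_norm pass, e.g.:
#
#     total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#     wandb.log({"grads/global_norm": total_norm.item()}, step=global_step)
#
# The standalone helper is kept because it also gives per-module norms.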


def make_opt_sched(model, enc_lr, thr_lr, total_steps, warmup_ratio, eta_min):
    """
    Create the optimizer and LR scheduler for training.

    The optimizer has 1-2 param groups: [0] = encoder + heads,
    [1] = thresholds (only when enabled).
    """
    encoder_params = [p for n, p in model.named_parameters()
                      if not (model.use_thresholds and 'tau_logit' in n)]
    param_groups = [{"params": encoder_params, "lr": enc_lr, "weight_decay": 0.0}]

    if model.use_thresholds:
        thr_params = [model.onderwerp_tau_logit, model.beleving_tau_logit]
        param_groups.append({"params": thr_params, "lr": thr_lr, "weight_decay": 0.0})

    optimizer = torch.optim.AdamW(param_groups)

    # Linear warmup into cosine decay; warmup is clamped to [1, total_steps - 1].
    warmup_steps = min(max(1, int(warmup_ratio * total_steps)), max(1, total_steps - 1))
    warmup = LinearLR(optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps)
    cosine = CosineAnnealingLR(optimizer, T_max=max(1, total_steps - warmup_steps), eta_min=eta_min)
    scheduler = SequentialLR(optimizer, [warmup, cosine], milestones=[warmup_steps])

    return optimizer, scheduler
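

# Hedged preview (a minimal sketch, not called during training): step a dummy
# optimizer through the same warmup + cosine construction to inspect the LR
# curve. The single linear layer is a stand-in for illustration only.
def _demo_lr_schedule(total_steps=100):
    dummy = nn.Linear(2, 2)
    opt = torch.optim.AdamW(dummy.parameters(), lr=8e-5)
    warmup_steps = max(1, int(0.1 * total_steps))
    warmup = LinearLR(opt, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps)
    cosine = CosineAnnealingLR(opt, T_max=total_steps - warmup_steps, eta_min=1e-6)
    sched = SequentialLR(opt, [warmup, cosine], milestones=[warmup_steps])
    lrs = []
    for _ in range(total_steps):
        lrs.append(sched.get_last_lr()[0])
        opt.step()
        sched.step()
    print(f"peak LR {max(lrs):.2e} at step {lrs.index(max(lrs))}")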


def run_epochs(model, tokenizer, train_loader, val_texts, val_onderwerp, val_beleving,
               onderwerp_names, beleving_names, device,
               *, start_epoch, end_epoch, phase_name="train",
               optimizer, scheduler, temperature, alpha,
               max_length, global_step):
    """
    Run training for a range of epochs.

    Args:
        model: The model to train
        tokenizer: Tokenizer for text encoding
        train_loader: DataLoader for training batches
        val_texts, val_onderwerp, val_beleving: Validation data
        onderwerp_names, beleving_names: Label names
        device: Device to train on
        start_epoch: Starting epoch (inclusive)
        end_epoch: Ending epoch (exclusive)
        phase_name: Name for logging (default: "train")
        optimizer: Optimizer
        scheduler: LR scheduler
        temperature: Soft-F1 temperature
        alpha: Loss weighting (F1 vs BCE)
        max_length: Max sequence length
        global_step: Starting global step counter

    Returns:
        Updated global_step
    """
    num_epochs = end_epoch - start_epoch
    phase_total_steps = max(1, len(train_loader) * num_epochs)

    model.train()

    for epoch in range(start_epoch, end_epoch):
        total_loss = 0
        total_onderwerp_f1 = 0
        total_beleving_f1 = 0
        total_bce_loss = 0
        total_f1_loss = 0
        num_batches = 0

        print(f"\n[{phase_name.upper()}] Epoch {epoch + 1}/{end_epoch}")
        print("-" * 40)

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            onderwerp_labels = batch['onderwerp_labels'].to(device)
            beleving_labels = batch['beleving_labels'].to(device)

            optimizer.zero_grad()

            onderwerp_logits, beleving_logits = model(input_ids, attention_mask)

            # Soft-F1 per head; thresholds enter the loss only when enabled.
            onderwerp_f1 = calculate_soft_f1(
                onderwerp_logits, onderwerp_labels,
                model.onderwerp_tau_logit if model.use_thresholds else None,
                temperature
            )
            beleving_f1 = calculate_soft_f1(
                beleving_logits, beleving_labels,
                model.beleving_tau_logit if model.use_thresholds else None,
                temperature
            )

            bce_onderwerp = F.binary_cross_entropy_with_logits(onderwerp_logits, onderwerp_labels)
            bce_beleving = F.binary_cross_entropy_with_logits(beleving_logits, beleving_labels)

            # Combined loss: α * mean Soft-F1 loss + (1 - α) * mean BCE.
            f1_loss = (1 - onderwerp_f1) + (1 - beleving_f1)
            bce_loss = bce_onderwerp + bce_beleving
            loss = alpha * (f1_loss / 2) + (1 - alpha) * (bce_loss / 2)

            # Periodic diagnostics: hard predictions, LRs, and threshold stats.
            if batch_idx % 20 == 0:
                with torch.no_grad():
                    onderwerp_probs = torch.sigmoid(onderwerp_logits)
                    beleving_probs = torch.sigmoid(beleving_logits)
                    if model.use_thresholds:
                        tau_on = logit_to_prob(model.onderwerp_tau_logit)
                        tau_be = logit_to_prob(model.beleving_tau_logit)
                    else:
                        tau_on = torch.full_like(onderwerp_probs[0], 0.5)
                        tau_be = torch.full_like(beleving_probs[0], 0.5)
                    onderwerp_pred = (onderwerp_probs > tau_on).float()
                    beleving_pred = (beleving_probs > tau_be).float()

                lrs = scheduler.get_last_lr()
                encoder_head_lr = lrs[0]
                threshold_lr = lrs[1] if len(lrs) > 1 else None

                if model.use_thresholds:
                    onderwerp_thresh_mean = logit_to_prob(model.onderwerp_tau_logit).mean().item()
                    onderwerp_thresh_min = logit_to_prob(model.onderwerp_tau_logit).min().item()
                    onderwerp_thresh_max = logit_to_prob(model.onderwerp_tau_logit).max().item()
                    beleving_thresh_mean = logit_to_prob(model.beleving_tau_logit).mean().item()
                    beleving_thresh_min = logit_to_prob(model.beleving_tau_logit).min().item()
                    beleving_thresh_max = logit_to_prob(model.beleving_tau_logit).max().item()
                else:
                    onderwerp_thresh_mean = onderwerp_thresh_min = onderwerp_thresh_max = 0.5
                    beleving_thresh_mean = beleving_thresh_min = beleving_thresh_max = 0.5

                print(f"  Batch {batch_idx + 1} | Step {global_step + 1}/{phase_total_steps}:")
                if threshold_lr is not None:
                    print(f"    Total loss: {loss.item():.4f} (α={alpha} F1 + {1-alpha} BCE) | LR: enc_head={encoder_head_lr:.2e} thresh={threshold_lr:.2e}")
                else:
                    print(f"    Total loss: {loss.item():.4f} (α={alpha} F1 + {1-alpha} BCE) | LR: enc_head={encoder_head_lr:.2e}")
                print(f"    F1 loss: {(f1_loss/2).item():.4f} | BCE loss: {(bce_loss/2).item():.4f}")
                print(f"    Onderwerp F1: {onderwerp_f1.item():.4f} | BCE: {bce_onderwerp.item():.4f} | Thresh: {onderwerp_thresh_mean:.3f} [{onderwerp_thresh_min:.3f}-{onderwerp_thresh_max:.3f}]")
                print(f"    Beleving F1: {beleving_f1.item():.4f} | BCE: {bce_beleving.item():.4f} | Thresh: {beleving_thresh_mean:.3f} [{beleving_thresh_min:.3f}-{beleving_thresh_max:.3f}]")
                print(f"    Onderwerp preds: {int(onderwerp_pred.sum())} / {int(onderwerp_labels.sum())} true")
                print(f"    Beleving preds: {int(beleving_pred.sum())} / {int(beleving_labels.sum())} true")

                log_dict = {
                    "phase": phase_name,
                    "train/loss": loss.item(),
                    "train/f1_loss": (f1_loss / 2).item(),
                    "train/bce_loss": (bce_loss / 2).item(),
                    "train/onderwerp_f1": onderwerp_f1.item(),
                    "train/onderwerp_bce": bce_onderwerp.item(),
                    "train/beleving_f1": beleving_f1.item(),
                    "train/beleving_bce": bce_beleving.item(),
                    "train/encoder_head_lr": encoder_head_lr,
                    "train/onderwerp_threshold_mean": onderwerp_thresh_mean,
                    "train/onderwerp_threshold_min": onderwerp_thresh_min,
                    "train/onderwerp_threshold_max": onderwerp_thresh_max,
                    "train/beleving_threshold_mean": beleving_thresh_mean,
                    "train/beleving_threshold_min": beleving_thresh_min,
                    "train/beleving_threshold_max": beleving_thresh_max,
                }
                if threshold_lr is not None:
                    log_dict["train/threshold_lr"] = threshold_lr
                wandb.log(log_dict, step=global_step)

            loss.backward()

            # Gradient diagnostics (computed before clipping).
            with torch.no_grad():
                onderwerp_thresh_grad = (model.onderwerp_tau_logit.grad.abs().mean().item()
                                         if model.use_thresholds and model.onderwerp_tau_logit.grad is not None else 0.0)
                beleving_thresh_grad = (model.beleving_tau_logit.grad.abs().mean().item()
                                        if model.use_thresholds and model.beleving_tau_logit.grad is not None else 0.0)

                encoder_grad_norm = grad_l2_norm(model.encoder.parameters())
                onderwerp_head_grad_norm = grad_l2_norm(model.onderwerp_head.parameters())
                beleving_head_grad_norm = grad_l2_norm(model.beleving_head.parameters())
                global_grad_norm = grad_l2_norm(model.parameters())

            wandb.log({
                "phase": phase_name,
                "grads/threshold_onderwerp": onderwerp_thresh_grad,
                "grads/threshold_beleving": beleving_thresh_grad,
                "grads/encoder": encoder_grad_norm,
                "grads/onderwerp_head": onderwerp_head_grad_norm,
                "grads/beleving_head": beleving_head_grad_norm,
                "grads/global_norm": global_grad_norm,
            }, step=global_step)

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()

            global_step += 1
            total_loss += loss.item()
            total_onderwerp_f1 += onderwerp_f1.item()
            total_beleving_f1 += beleving_f1.item()
            total_f1_loss += (f1_loss / 2).item()
            total_bce_loss += (bce_loss / 2).item()
            num_batches += 1

        avg_loss = total_loss / max(1, num_batches)
        avg_onderwerp_f1 = total_onderwerp_f1 / max(1, num_batches)
        avg_beleving_f1 = total_beleving_f1 / max(1, num_batches)
        avg_f1_loss = total_f1_loss / max(1, num_batches)
        avg_bce_loss = total_bce_loss / max(1, num_batches)

        lrs = scheduler.get_last_lr()
        current_lr = lrs[0]

        if model.use_thresholds:
            onderwerp_thresh_mean = logit_to_prob(model.onderwerp_tau_logit).mean().item()
            onderwerp_thresh_std = logit_to_prob(model.onderwerp_tau_logit).std().item()
            beleving_thresh_mean = logit_to_prob(model.beleving_tau_logit).mean().item()
            beleving_thresh_std = logit_to_prob(model.beleving_tau_logit).std().item()
        else:
            onderwerp_thresh_mean = onderwerp_thresh_std = 0.5
            beleving_thresh_mean = beleving_thresh_std = 0.5

        print(f"\n  [{phase_name.upper()}] Epoch {epoch + 1} Summary:")
        print(f"    Average total loss: {avg_loss:.4f} (α={alpha} F1 + {1-alpha} BCE)")
        print(f"    Average F1 loss: {avg_f1_loss:.4f} | Average BCE loss: {avg_bce_loss:.4f}")
        print(f"    Average onderwerp F1: {avg_onderwerp_f1:.4f} | Threshold: {onderwerp_thresh_mean:.3f} (σ={onderwerp_thresh_std:.3f})")
        print(f"    Average beleving F1: {avg_beleving_f1:.4f} | Threshold: {beleving_thresh_mean:.3f} (σ={beleving_thresh_std:.3f})")
        print(f"    Average combined F1: {(avg_onderwerp_f1 + avg_beleving_f1) / 2:.4f}")
        print(f"    Current learning rate: {current_lr:.2e}")

        print("\n  Running validation on 200 samples...")
        val_metrics = evaluate(
            model, val_texts, val_onderwerp, val_beleving, tokenizer, device,
            onderwerp_names, beleving_names, num_samples=200, max_length=max_length
        )

        wandb.log({
            "phase": phase_name,
            "val/onderwerp_acc": val_metrics['onderwerp_acc'],
            "val/onderwerp_precision": val_metrics['onderwerp_precision'],
            "val/onderwerp_recall": val_metrics['onderwerp_recall'],
            "val/onderwerp_f1": val_metrics['onderwerp_f1'],
            "val/beleving_acc": val_metrics['beleving_acc'],
            "val/beleving_precision": val_metrics['beleving_precision'],
            "val/beleving_recall": val_metrics['beleving_recall'],
            "val/beleving_f1": val_metrics['beleving_f1'],
            "val/combined_acc": val_metrics['combined_acc'],
            "val/combined_f1": val_metrics['combined_f1'],
            "val/onderwerp_threshold_mean": val_metrics['onderwerp_thresh_mean'],
            "val/beleving_threshold_mean": val_metrics['beleving_thresh_mean'],
            "epoch": epoch + 1
        }, step=global_step)

        if model.use_thresholds:
            wandb.log({
                "phase": phase_name,
                "thresholds/onderwerp": wandb.Histogram(logit_to_prob(model.onderwerp_tau_logit).detach().cpu().numpy()),
                "thresholds/beleving": wandb.Histogram(logit_to_prob(model.beleving_tau_logit).detach().cpu().numpy()),
                "epoch": epoch + 1
            }, step=global_step)

        print(f"    Val onderwerp F1: {val_metrics['onderwerp_f1']:.4f} | Val beleving F1: {val_metrics['beleving_f1']:.4f}")
        print(f"    Val combined F1: {val_metrics['combined_f1']:.4f}")

        # evaluate() leaves the model in eval mode; switch back for training.
        model.train()

    return global_step
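

# Hedged sketch (illustration only, with hypothetical phase boundaries):
# because run_epochs takes start_epoch/end_epoch and a phase_name, it can be
# called more than once, e.g. a short warmup phase followed by full
# fine-tuning, each with its own optimizer/scheduler:
#
#     step = run_epochs(..., start_epoch=0, end_epoch=2, phase_name="head_warmup",
#                       optimizer=opt1, scheduler=sched1, ..., global_step=0)
#     step = run_epochs(..., start_epoch=2, end_epoch=15, phase_name="finetune",
#                       optimizer=opt2, scheduler=sched2, ..., global_step=step)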


def main():
    # TF32 matmul on Ampere+ GPUs; no effect on MPS/CPU.
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision('high')

    device = get_device()

    model_name = "jhu-clsp/mmBERT-base"

    default_config = dict(
        # Reproducibility
        seed=42,

        # Model / tokenization
        dropout=0.2,
        initial_threshold=0.565,
        max_length=1408,

        # Learnable per-class thresholds (off by default: fixed 0.5 cutoff)
        use_thresholds=False,

        # Optimization
        encoder_peak_lr=8e-5,
        threshold_lr_mult=5.0,
        num_epochs=15,
        batch_size=16,

        # Loss: alpha * Soft-F1 + (1 - alpha) * BCE
        alpha=0.15,
        temperature=2.0,

        # LR schedule
        warmup_ratio=0.1,
        min_lr=1e-6,
    )

    # wandb.config allows sweep overrides of the defaults above.
    wandb.init(project="wim-multilabel-mmbert", config=default_config)
    cfg = wandb.config

    set_seed(cfg.seed)

    print("\nLoading RD dataset...")
    texts, onderwerp, beleving, onderwerp_names, beleving_names = load_rd_wim_dataset(
        max_samples=None
    )

    print("\nDataset loaded:")
    print(f"  Samples: {len(texts)}")
    print(f"  Onderwerp labels: {len(onderwerp_names)}")
    print(f"  Beleving labels: {len(beleving_names)}")
    print(f"  Avg onderwerp per sample: {onderwerp.sum(axis=1).mean():.2f}")
    print(f"  Avg beleving per sample: {beleving.sum(axis=1).mean():.2f}")

    # Unpack the (possibly sweep-overridden) config.
    dropout = cfg.dropout
    initial_threshold = cfg.initial_threshold
    max_length = cfg.max_length
    encoder_peak_lr = cfg.encoder_peak_lr
    threshold_peak_lr = encoder_peak_lr * cfg.threshold_lr_mult
    num_epochs = cfg.num_epochs
    batch_size = cfg.batch_size
    alpha = cfg.alpha
    temperature = cfg.temperature
    warmup_ratio = cfg.warmup_ratio
    min_lr = cfg.min_lr

    print("\nLoading mmBERT-base tokenizer and creating dual-head model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = mmBERTDualHead(
        model_name=model_name,
        num_onderwerp=len(onderwerp_names),
        num_beleving=len(beleving_names),
        dropout=dropout,
        initial_threshold=initial_threshold,
        use_thresholds=cfg.use_thresholds
    )

    model = model.to(device)

    # Align threshold parameters with the encoder dtype (a no-op when both are
    # fp32). Assigning .data actually changes the parameter dtype, whereas a
    # copy_() of a parameter into itself would not.
    encoder_dtype = next(model.encoder.parameters()).dtype
    with torch.no_grad():
        if model.use_thresholds:
            model.onderwerp_tau_logit.data = model.onderwerp_tau_logit.data.to(encoder_dtype)
            model.beleving_tau_logit.data = model.beleving_tau_logit.data.to(encoder_dtype)

    print(f"Model loaded and moved to {device}")
    print(f"  Onderwerp head: {len(onderwerp_names)} outputs")
    print(f"  Beleving head: {len(beleving_names)} outputs")

    # Simple sequential 80/20 split (no shuffling here; the dataset ordering is
    # assumed acceptable for a holdout split).
    split_idx = int(0.8 * len(texts))
    train_texts = texts[:split_idx]
    train_onderwerp = onderwerp[:split_idx]
    train_beleving = beleving[:split_idx]
    val_texts = texts[split_idx:]
    val_onderwerp = onderwerp[split_idx:]
    val_beleving = beleving[split_idx:]

    print("\nData split:")
    print(f"  Train: {len(train_texts)} samples")
    print(f"  Val: {len(val_texts)} samples")

    train_dataset = DutchDualLabelDataset(
        train_texts, train_onderwerp, train_beleving, tokenizer, max_length
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    steps_per_epoch = len(train_loader)
    total_training_steps = steps_per_epoch * num_epochs

    wandb.config.update({
        # Model
        "model_name": model_name,
        "num_onderwerp": len(onderwerp_names),
        "num_beleving": len(beleving_names),

        # Derived optimization settings
        "threshold_peak_lr": threshold_peak_lr,
        "total_training_steps": total_training_steps,

        # Data
        "train_samples": len(train_texts),
        "val_samples": len(val_texts),
        "total_samples": len(texts),
        "split_ratio": 0.8,

        # Loss
        "loss_type": "combined_f1_bce",
        "f1_weight": alpha,
        "bce_weight": 1 - alpha,

        # Training features
        "learnable_thresholds": cfg.use_thresholds,
        "per_class_thresholds": cfg.use_thresholds,
        "gradient_clipping": True,
        "max_grad_norm": 1.0,
    }, allow_val_change=True)

    print(f"\nStarting training for {num_epochs} total epochs with COMBINED F1+BCE LOSS...")
    print(f"Loss formula: {alpha} * (1-F1) + {1-alpha} * BCE")
    print(f"Temperature for Soft-F1: {temperature} | Initial thresholds: {initial_threshold}")
    print(f"Batch size: {batch_size} | Total training batches: {steps_per_epoch}")
    if cfg.use_thresholds:
        print("Learnable thresholds enabled for both onderwerp and beleving heads")
    else:
        print("Learnable thresholds disabled (fixed cutoff τ=0.5 for both heads)")
    print("=" * 60)

    print(f"\n{'='*60}")
    print(f"TRAINING: {num_epochs} epoch(s)")
    print(f"{'='*60}")

    optimizer, scheduler = make_opt_sched(
        model,
        enc_lr=encoder_peak_lr,
        thr_lr=threshold_peak_lr,
        total_steps=total_training_steps,
        warmup_ratio=warmup_ratio,
        eta_min=min_lr
    )

    global_step = run_epochs(
        model, tokenizer, train_loader,
        val_texts, val_onderwerp, val_beleving,
        onderwerp_names, beleving_names, device,
        start_epoch=0, end_epoch=num_epochs,
        phase_name="train",
        optimizer=optimizer, scheduler=scheduler,
        temperature=temperature, alpha=alpha,
        max_length=max_length, global_step=0
    )

    print(f"\n{'='*60}")
    print("TRAINING COMPLETE")
    print(f"{'='*60}")

    print("\n" + "=" * 60)
    print("FINAL EVALUATION ON VALIDATION SET")
    print("=" * 60)

    print("\nEvaluating on 500 validation samples...")
    final_metrics = evaluate(
        model, val_texts, val_onderwerp, val_beleving, tokenizer, device,
        onderwerp_names, beleving_names, num_samples=500, max_length=max_length
    )

    print("\n" + "=" * 60)
    print(f"FINAL METRICS (on {final_metrics['num_samples_evaluated']} validation samples)")
    print("-" * 40)

    print("  Onderwerp:")
    print(f"    Accuracy: {final_metrics['onderwerp_acc']:.1%}")
    print(f"    Precision: {final_metrics['onderwerp_precision']:.3f}")
    print(f"    Recall: {final_metrics['onderwerp_recall']:.3f}")
    print(f"    F1 Score: {final_metrics['onderwerp_f1']:.3f}")

    print("\n  Beleving:")
    print(f"    Accuracy: {final_metrics['beleving_acc']:.1%}")
    print(f"    Precision: {final_metrics['beleving_precision']:.3f}")
    print(f"    Recall: {final_metrics['beleving_recall']:.3f}")
    print(f"    F1 Score: {final_metrics['beleving_f1']:.3f}")

    print("\n  Combined:")
    print(f"    Average Accuracy: {final_metrics['combined_acc']:.1%}")
    print(f"    Average F1: {final_metrics['combined_f1']:.3f}")

    wandb.log({
        "final/onderwerp_acc": final_metrics['onderwerp_acc'],
        "final/onderwerp_precision": final_metrics['onderwerp_precision'],
        "final/onderwerp_recall": final_metrics['onderwerp_recall'],
        "final/onderwerp_f1": final_metrics['onderwerp_f1'],
        "final/beleving_acc": final_metrics['beleving_acc'],
        "final/beleving_precision": final_metrics['beleving_precision'],
        "final/beleving_recall": final_metrics['beleving_recall'],
        "final/beleving_f1": final_metrics['beleving_f1'],
        "final/combined_acc": final_metrics['combined_acc'],
        "final/combined_f1": final_metrics['combined_f1'],
    }, step=global_step)

    print("\n" + "=" * 60)
    print("Training complete! 🎉")
    print("mmBERT-base dual-head architecture with balanced F1+BCE loss")
    print(f"Loss formula: {alpha} * (1-F1) + {1-alpha} * BCE")
    print(f"Temperature: {temperature}")
    if cfg.use_thresholds:
        print("Learned per-class thresholds:")
        print(f"  Onderwerp ({len(onderwerp_names)} classes): mean={final_metrics['onderwerp_thresh_mean']:.3f} [{final_metrics['onderwerp_thresh_min']:.3f}-{final_metrics['onderwerp_thresh_max']:.3f}] σ={final_metrics['onderwerp_thresh_std']:.3f}")
        print(f"  Beleving ({len(beleving_names)} classes): mean={final_metrics['beleving_thresh_mean']:.3f} [{final_metrics['beleving_thresh_min']:.3f}-{final_metrics['beleving_thresh_max']:.3f}] σ={final_metrics['beleving_thresh_std']:.3f}")
    else:
        print("Thresholds disabled (fixed cutoff τ=0.5 for both heads).")
    print("With gradient clipping (max_norm=1.0) and warmup LR schedule")
    print(f"Full dataset: {len(texts)} samples | Batch size: {batch_size} | Epochs: {num_epochs}")
    print(f"mmBERT: modern multilingual encoder (1800+ languages, max_length: {max_length})")

    # Plain state_dict of the full dual-head model.
    save_path = "mmbert_dual_head_final.pt"
    torch.save(model.state_dict(), save_path)
    print(f"\nModel weights saved to {save_path}")

    # HF-compatible checkpoint: encoder + tokenizer via save_pretrained, with
    # the task heads (and thresholds, if enabled) in a side file.
    hf_dir = "mmbert_dual_head_hf"
    os.makedirs(hf_dir, exist_ok=True)

    model.encoder.save_pretrained(hf_dir)
    tokenizer.save_pretrained(hf_dir)

    head_state = {
        "onderwerp_head_state": model.onderwerp_head.state_dict(),
        "beleving_head_state": model.beleving_head.state_dict(),
        "use_thresholds": model.use_thresholds,
        "num_onderwerp": len(onderwerp_names),
        "num_beleving": len(beleving_names),
        "dropout": dropout,
        "max_length": max_length,
        "alpha": alpha,
        "temperature": temperature,
        "model_name": model_name,
    }
    if model.use_thresholds:
        head_state["onderwerp_tau_logit"] = model.onderwerp_tau_logit.detach().cpu()
        head_state["beleving_tau_logit"] = model.beleving_tau_logit.detach().cpu()
    torch.save(head_state, os.path.join(hf_dir, "dual_head_state.pt"))

    with open(os.path.join(hf_dir, "label_names.json"), "w") as f:
        json.dump({
            "onderwerp": list(map(str, onderwerp_names)),
            "beleving": list(map(str, beleving_names))
        }, f, ensure_ascii=False, indent=2)
    print(f"HF-compatible checkpoint saved to '{hf_dir}' (encoder+tokenizer), with heads in dual_head_state.pt")

    wandb.finish()
    print("\nWandB logging completed and run finished.")


if __name__ == "__main__":
    main()