|
|
|
|
|
""" |
|
|
UWV/wim-synthetic-data-rd dataset loader for multi-label classification. |
|
|
Carmack-style: minimal abstraction, direct data flow, fast operations. |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
from datasets import load_dataset, concatenate_datasets |
|
|
|
|
|
|
|
|
def load_rd_wim_dataset(max_samples=None, split='train', filter_calamity=True):
    """
    Load combined UWV datasets and encode multi-labels.

    Combines two datasets:
    - UWV/wim-synthetic-data-rd: Original RD dataset
    - UWV/wim_synthetic_data_for_testing_split_labels: Validated testing dataset

    Both are reduced to the columns text / onderwerp_labels / beleving_labels,
    concatenated, and shuffled with a fixed seed (42) so the sample order is
    reproducible.

    Dataset contains Dutch municipal complaint conversations with two types of labels:
    - onderwerp: What the message is about
    - beleving: How the citizen experienced the interaction

    The placeholder label 'No subtopic found' is never treated as a real
    class: it is removed from every label list, so a sample that only had the
    placeholder ends up with an all-zero row. A missing (None) label list is
    treated as empty.

    Args:
        max_samples: Limit number of samples (None = all samples). The limit is
            applied right after shuffling, so only the returned samples are
            cleaned, counted and encoded.
        split: Dataset split to load, used for both datasets (default: 'train')
        filter_calamity: If True, exclude samples with is_calamity=True from RD dataset (default: True)

    Returns:
        texts: List of conversation strings
        onderwerp_encoded: numpy array [n_samples, n_onderwerp] - multi-hot encoded topics
        beleving_encoded: numpy array [n_samples, n_beleving] - multi-hot encoded experiences
        onderwerp_labels: List of onderwerp label names (sorted alphabetically)
        beleving_labels: List of beleving label names (sorted alphabetically)
    """
    print(f"Loading UWV/wim-synthetic-data-rd dataset (split={split})...")
    ds_rd = load_dataset('UWV/wim-synthetic-data-rd', split=split)

    if filter_calamity:
        original_len = len(ds_rd)
        ds_rd = ds_rd.filter(lambda x: not x['is_calamity'])
        filtered_len = len(ds_rd)
        print(f"Filtered out {original_len - filtered_len} calamity samples ({filtered_len} remaining)")

    # Keep only the columns shared with the testing dataset so both can be concatenated.
    ds_rd = ds_rd.select_columns(['text', 'onderwerp_labels', 'beleving_labels'])
    print(f"RD dataset: {len(ds_rd)} samples")

    print(f"Loading UWV/wim_synthetic_data_for_testing_split_labels dataset (split={split})...")
    ds_test = load_dataset('UWV/wim_synthetic_data_for_testing_split_labels', split=split)

    # Rename the testing dataset's columns to the RD column names and drop the rest.
    ds_test = ds_test.map(lambda x: {
        'text': x['Synthetic Text'],
        'onderwerp_labels': x['validated_onderwerp_labels'],
        'beleving_labels': x['validated_beleving_labels']
    }, remove_columns=ds_test.column_names)
    print(f"Testing dataset: {len(ds_test)} samples")

    ds = concatenate_datasets([ds_rd, ds_test])
    print(f"Combined dataset: {len(ds)} samples")

    ds = ds.shuffle(seed=42)
    print("Shuffled combined dataset")

    # Limit before any per-sample work: the cleanup below is per-row, so
    # selecting first yields the same samples without processing the rest.
    if max_samples is not None:
        ds = ds.select(range(min(max_samples, len(ds))))

    # Single pass over the rows; everything afterwards works on plain lists.
    texts = []
    onderwerp_lists = []
    beleving_lists = []
    for sample in ds:
        texts.append(sample['text'])
        onderwerp_lists.append(_drop_placeholder(sample['onderwerp_labels']))
        beleving_lists.append(_drop_placeholder(sample['beleving_labels']))

    # Counts every sample left without labels (placeholder-only or already empty).
    no_onderwerp_count = sum(1 for labels in onderwerp_lists if not labels)
    no_beleving_count = sum(1 for labels in beleving_lists if not labels)
    print(f"Replaced 'No subtopic found' with empty list: {no_onderwerp_count} onderwerp, {no_beleving_count} beleving")

    n_samples = len(texts)
    print(f"Loaded {n_samples} samples")

    # Sorted vocabularies give a stable column order for the encoded matrices.
    onderwerp_labels = sorted({label for labels in onderwerp_lists for label in labels})
    beleving_labels = sorted({label for labels in beleving_lists for label in labels})

    print(f"Found {len(onderwerp_labels)} unique onderwerp labels")
    print(f"Found {len(beleving_labels)} unique beleving labels")

    onderwerp_encoded = _multi_hot(onderwerp_lists, onderwerp_labels)
    beleving_encoded = _multi_hot(beleving_lists, beleving_labels)

    print(f"Encoded {n_samples} samples")
    print(f" onderwerp shape: {onderwerp_encoded.shape}")
    print(f" beleving shape: {beleving_encoded.shape}")

    return texts, onderwerp_encoded, beleving_encoded, onderwerp_labels, beleving_labels


def _drop_placeholder(labels, placeholder='No subtopic found'):
    """Return labels as a new list without the placeholder entry (None -> [])."""
    return [label for label in (labels or []) if label != placeholder]


def _multi_hot(label_lists, vocabulary):
    """Encode per-sample label lists as a float32 [n_samples, n_labels] 0/1 matrix."""
    column_of = {label: col for col, label in enumerate(vocabulary)}
    encoded = np.zeros((len(label_lists), len(vocabulary)), dtype=np.float32)
    for row, labels in enumerate(label_lists):
        for label in labels:
            encoded[row, column_of[label]] = 1.0
    return encoded
|
|
|
|
|
|
|
|
def print_sample_info(texts, onderwerp_encoded, beleving_encoded,
                      onderwerp_labels, beleving_labels, sample_idx=0):
    """
    Print information about a specific sample (useful for debugging).

    Shows a preview of the text (first 200 characters) followed by the names
    of the onderwerp and beleving labels whose encoded value equals 1.

    Args:
        texts: List of conversation strings
        onderwerp_encoded: Per-sample rows of 0/1 values, one column per onderwerp label
        beleving_encoded: Per-sample rows of 0/1 values, one column per beleving label
        onderwerp_labels: Label names matching the columns of onderwerp_encoded
        beleving_labels: Label names matching the columns of beleving_encoded
        sample_idx: Which sample to print (default: 0)

    The first five arguments are the outputs of load_rd_wim_dataset(), in order.

    Raises:
        IndexError: If sample_idx is out of range for the given inputs.
    """
    preview_chars = 200
    text = texts[sample_idx]
    # Only mark the preview as truncated when the text was actually cut off.
    ellipsis = "..." if len(text) > preview_chars else ""

    print(f"\n{'='*60}")
    print(f"SAMPLE {sample_idx}")
    print(f"{'='*60}")
    print(f"Text: {text[:preview_chars]}{ellipsis}")
    print()

    onderwerp_active = [onderwerp_labels[i] for i, val in enumerate(onderwerp_encoded[sample_idx]) if val == 1]
    print(f"Onderwerp labels ({len(onderwerp_active)}):")
    for label in onderwerp_active:
        print(f" - {label}")
    print()

    beleving_active = [beleving_labels[i] for i, val in enumerate(beleving_encoded[sample_idx]) if val == 1]
    print(f"Beleving labels ({len(beleving_active)}):")
    for label in beleving_active:
        print(f" - {label}")
    print(f"{'='*60}\n")
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
print("Testing UWV dataset loader...\n") |
|
|
|
|
|
|
|
|
texts, onderwerp, beleving, onderwerp_names, beleving_names = load_rd_wim_dataset(max_samples=10) |
|
|
|
|
|
|
|
|
print_sample_info(texts, onderwerp, beleving, onderwerp_names, beleving_names, sample_idx=0) |
|
|
|
|
|
|
|
|
print("\nDataset Statistics:") |
|
|
print(f" Total samples: {len(texts)}") |
|
|
print(f" Avg onderwerp labels per sample: {onderwerp.sum(axis=1).mean():.2f}") |
|
|
print(f" Avg beleving labels per sample: {beleving.sum(axis=1).mean():.2f}") |
|
|
print(f" Text length range: {min(len(t) for t in texts)} - {max(len(t) for t in texts)} chars") |
|
|
|