#!/usr/bin/env python3
"""
UWV/wim-synthetic-data-rd dataset loader for multi-label classification.
Carmack-style: minimal abstraction, direct data flow, fast operations.
"""

import numpy as np
from datasets import load_dataset, concatenate_datasets


def load_rd_wim_dataset(max_samples=None, split='train', filter_calamity=True):
    """
    Load combined UWV datasets and encode multi-labels.

    Combines two datasets:
    - UWV/wim-synthetic-data-rd: Original RD dataset
    - UWV/wim_synthetic_data_for_testing_split_labels: Validated testing dataset

    The dataset contains Dutch municipal complaint conversations with two types of labels:
    - onderwerp: What the message is about
    - beleving: How the citizen experienced the interaction

    Args:
        max_samples: Limit number of samples (None = all samples)
        split: Dataset split to load (default: 'train')
        filter_calamity: If True, exclude samples with is_calamity=True from the RD dataset (default: True)

    Returns:
        texts: List of conversation strings
        onderwerp_encoded: numpy array [n_samples, n_onderwerp] - multi-hot encoded topics
        beleving_encoded: numpy array [n_samples, n_beleving] - multi-hot encoded experiences
        onderwerp_labels: List of onderwerp label names (sorted alphabetically)
        beleving_labels: List of beleving label names (sorted alphabetically)
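
    Example (illustrative; these label names are hypothetical):
        With onderwerp_labels == ['betaling', 'uitkering', 'wachttijd'],
        a sample tagged ['betaling', 'wachttijd'] is encoded as the row
        [1.0, 0.0, 1.0] in onderwerp_encoded.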
"""
    # Load RD dataset
    print(f"Loading UWV/wim-synthetic-data-rd dataset (split={split})...")
    ds_rd = load_dataset('UWV/wim-synthetic-data-rd', split=split)

    # Filter out calamity samples if requested
    if filter_calamity:
        original_len = len(ds_rd)
        ds_rd = ds_rd.filter(lambda x: not x['is_calamity'])
        filtered_len = len(ds_rd)
        print(f"Filtered out {original_len - filtered_len} calamity samples ({filtered_len} remaining)")

    # Keep only essential columns from RD dataset
    ds_rd = ds_rd.select_columns(['text', 'onderwerp_labels', 'beleving_labels'])
    print(f"RD dataset: {len(ds_rd)} samples")
    # Load testing dataset
    print(f"Loading UWV/wim_synthetic_data_for_testing_split_labels dataset (split={split})...")
    ds_test = load_dataset('UWV/wim_synthetic_data_for_testing_split_labels', split=split)

    # Rename columns to match RD dataset structure
    ds_test = ds_test.map(lambda x: {
        'text': x['Synthetic Text'],
        'onderwerp_labels': x['validated_onderwerp_labels'],
        'beleving_labels': x['validated_beleving_labels']
    }, remove_columns=ds_test.column_names)
    print(f"Testing dataset: {len(ds_test)} samples")
    # Concatenate datasets
    ds = concatenate_datasets([ds_rd, ds_test])
    print(f"Combined dataset: {len(ds)} samples")

    # Shuffle with fixed seed for reproducibility
    ds = ds.shuffle(seed=42)
    print("Shuffled combined dataset")
    # Replace the 'No subtopic found' placeholder with an empty list
    # (for both onderwerp and beleving)
    ds = ds.map(lambda x: {
        **x,
        'onderwerp_labels': [] if x['onderwerp_labels'] == ['No subtopic found'] else x['onderwerp_labels'],
        'beleving_labels': [] if x['beleving_labels'] == ['No subtopic found'] else x['beleving_labels']
    })
    no_onderwerp_count = sum(1 for sample in ds if len(sample['onderwerp_labels']) == 0)
    no_beleving_count = sum(1 for sample in ds if len(sample['beleving_labels']) == 0)
    print(f"Replaced 'No subtopic found' placeholders; samples now without labels: {no_onderwerp_count} onderwerp, {no_beleving_count} beleving")
    # Limit samples if requested
    if max_samples is not None:
        ds = ds.select(range(min(max_samples, len(ds))))
    print(f"Loaded {len(ds)} samples")
    # Extract all unique labels from the loaded dataset
    # (note: after any max_samples truncation above)
    onderwerp_set = set()
    beleving_set = set()
    for sample in ds:
        for label in sample['onderwerp_labels']:
            onderwerp_set.add(label)
        for label in sample['beleving_labels']:
            beleving_set.add(label)

    # Sort labels alphabetically for consistent indexing across runs
    onderwerp_labels = sorted(onderwerp_set)
    beleving_labels = sorted(beleving_set)
    print(f"Found {len(onderwerp_labels)} unique onderwerp labels")
    print(f"Found {len(beleving_labels)} unique beleving labels")
    # Create label -> index mappings
    onderwerp_to_idx = {label: idx for idx, label in enumerate(onderwerp_labels)}
    beleving_to_idx = {label: idx for idx, label in enumerate(beleving_labels)}

    # Encode labels to multi-hot vectors
    n_samples = len(ds)
    n_onderwerp = len(onderwerp_labels)
    n_beleving = len(beleving_labels)

    # Preallocate arrays (faster than appending)
    texts = []
    onderwerp_encoded = np.zeros((n_samples, n_onderwerp), dtype=np.float32)
    beleving_encoded = np.zeros((n_samples, n_beleving), dtype=np.float32)
    # Fill arrays
    for i, sample in enumerate(ds):
        texts.append(sample['text'])

        # Encode onderwerp labels (multi-hot)
        for label in sample['onderwerp_labels']:
            idx = onderwerp_to_idx[label]
            onderwerp_encoded[i, idx] = 1.0

        # Encode beleving labels (multi-hot)
        for label in sample['beleving_labels']:
            idx = beleving_to_idx[label]
            beleving_encoded[i, idx] = 1.0

    print(f"Encoded {n_samples} samples")
    print(f"  onderwerp shape: {onderwerp_encoded.shape}")
    print(f"  beleving shape: {beleving_encoded.shape}")
    return texts, onderwerp_encoded, beleving_encoded, onderwerp_labels, beleving_labels
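

# Mapping a multi-hot row back to label names is handy when eyeballing the
# encoded output. A minimal convenience sketch (not part of the original
# loader); it assumes a row from onderwerp_encoded/beleving_encoded plus the
# matching label list returned by load_rd_wim_dataset().
def decode_labels(encoded_row, label_names):
    """Decode a single multi-hot row back to its active label names."""
    return [label_names[i] for i, val in enumerate(encoded_row) if val == 1.0]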


def print_sample_info(texts, onderwerp_encoded, beleving_encoded,
                      onderwerp_labels, beleving_labels, sample_idx=0):
    """
    Print information about a specific sample (useful for debugging).

    Args:
        All outputs from load_rd_wim_dataset()
        sample_idx: Which sample to print (default: 0)
    """
    print(f"\n{'='*60}")
    print(f"SAMPLE {sample_idx}")
    print(f"{'='*60}")
    print(f"Text: {texts[sample_idx][:200]}...")
    print()

    # Get active onderwerp labels
    onderwerp_active = [onderwerp_labels[i] for i, val in enumerate(onderwerp_encoded[sample_idx]) if val == 1]
    print(f"Onderwerp labels ({len(onderwerp_active)}):")
    for label in onderwerp_active:
        print(f"  - {label}")
    print()

    # Get active beleving labels
    beleving_active = [beleving_labels[i] for i, val in enumerate(beleving_encoded[sample_idx]) if val == 1]
    print(f"Beleving labels ({len(beleving_active)}):")
    for label in beleving_active:
        print(f"  - {label}")
    print(f"{'='*60}\n")


if __name__ == "__main__":
    # Test the loader
    print("Testing UWV dataset loader...\n")

    # Load small subset for testing
    texts, onderwerp, beleving, onderwerp_names, beleving_names = load_rd_wim_dataset(max_samples=10)

    # Print first sample
    print_sample_info(texts, onderwerp, beleving, onderwerp_names, beleving_names, sample_idx=0)

    # Print statistics
    print("\nDataset Statistics:")
    print(f"  Total samples: {len(texts)}")
    print(f"  Avg onderwerp labels per sample: {onderwerp.sum(axis=1).mean():.2f}")
    print(f"  Avg beleving labels per sample: {beleving.sum(axis=1).mean():.2f}")
    print(f"  Text length range: {min(len(t) for t in texts)} - {max(len(t) for t in texts)} chars")
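
    # Example downstream use (illustrative, assumes scikit-learn is
    # installed): an 80/20 split of the encoded arrays for training a
    # multi-label classifier.
    #
    #   from sklearn.model_selection import train_test_split
    #   X_tr, X_va, y_tr, y_va = train_test_split(
    #       texts, onderwerp, test_size=0.2, random_state=42)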