#!/usr/bin/env python3
"""
UWV/wim-synthetic-data-rd dataset loader for multi-label classification.

Carmack-style: minimal abstraction, direct data flow, fast operations.
"""

import numpy as np
from datasets import load_dataset, concatenate_datasets


def load_rd_wim_dataset(max_samples=None, split='train', filter_calamity=True):
    """
    Load combined UWV datasets and encode multi-labels.

    Combines two datasets:
    - UWV/wim-synthetic-data-rd: Original RD dataset
    - UWV/wim_synthetic_data_for_testing_split_labels: Validated testing dataset

    Dataset contains Dutch municipal complaint conversations with two types of labels:
    - onderwerp: What the message is about
    - beleving: How the citizen experienced the interaction

    Args:
        max_samples: Limit number of samples (None = all samples)
        split: Dataset split to load (default: 'train')
        filter_calamity: If True, exclude samples with is_calamity=True from RD dataset (default: True)

    Returns:
        texts: List of conversation strings
        onderwerp_encoded: numpy array [n_samples, n_onderwerp] - multi-hot encoded topics
        beleving_encoded: numpy array [n_samples, n_beleving] - multi-hot encoded experiences
        onderwerp_labels: List of onderwerp label names (sorted alphabetically)
        beleving_labels: List of beleving label names (sorted alphabetically)
    """
    # Load RD dataset
    print(f"Loading UWV/wim-synthetic-data-rd dataset (split={split})...")
    ds_rd = load_dataset('UWV/wim-synthetic-data-rd', split=split)

    # Filter out calamity samples if requested
    if filter_calamity:
        original_len = len(ds_rd)
        ds_rd = ds_rd.filter(lambda x: not x['is_calamity'])
        filtered_len = len(ds_rd)
        print(f"Filtered out {original_len - filtered_len} calamity samples ({filtered_len} remaining)")

    # Keep only essential columns from RD dataset
    ds_rd = ds_rd.select_columns(['text', 'onderwerp_labels', 'beleving_labels'])
    print(f"RD dataset: {len(ds_rd)} samples")

    # Load testing dataset
    print(f"Loading UWV/wim_synthetic_data_for_testing_split_labels dataset (split={split})...")
    ds_test = load_dataset('UWV/wim_synthetic_data_for_testing_split_labels', split=split)

    # Rename columns to match RD dataset structure
    ds_test = ds_test.map(lambda x: {
        'text': x['Synthetic Text'],
        'onderwerp_labels': x['validated_onderwerp_labels'],
        'beleving_labels': x['validated_beleving_labels']
    }, remove_columns=ds_test.column_names)
    print(f"Testing dataset: {len(ds_test)} samples")

    # Concatenate datasets
    ds = concatenate_datasets([ds_rd, ds_test])
    print(f"Combined dataset: {len(ds)} samples")

    # Shuffle with fixed seed for reproducibility
    ds = ds.shuffle(seed=42)
    print("Shuffled combined dataset")

    # Replace "No subtopic found" with empty list (for both onderwerp and beleving)
    ds = ds.map(lambda x: {
        **x,
        'onderwerp_labels': [] if x['onderwerp_labels'] == ['No subtopic found'] else x['onderwerp_labels'],
        'beleving_labels': [] if x['beleving_labels'] == ['No subtopic found'] else x['beleving_labels']
    })
    no_onderwerp_count = sum(1 for sample in ds if len(sample['onderwerp_labels']) == 0)
    no_beleving_count = sum(1 for sample in ds if len(sample['beleving_labels']) == 0)
    print(f"Replaced 'No subtopic found' with empty list: {no_onderwerp_count} onderwerp, {no_beleving_count} beleving")

    # Limit samples if requested
    if max_samples is not None:
        ds = ds.select(range(min(max_samples, len(ds))))

    print(f"Loaded {len(ds)} samples")

    # Extract all unique labels from the entire dataset
    onderwerp_set = set()
    beleving_set = set()
    for sample in ds:
        for label in sample['onderwerp_labels']:
            onderwerp_set.add(label)
        for label in sample['beleving_labels']:
            beleving_set.add(label)

    # Sort labels alphabetically for consistent indexing across runs
    onderwerp_labels = sorted(onderwerp_set)
    beleving_labels = sorted(beleving_set)
    print(f"Found {len(onderwerp_labels)} unique onderwerp labels")
    print(f"Found {len(beleving_labels)} unique beleving labels")

    # Create label -> index mappings
    onderwerp_to_idx = {label: idx for idx, label in enumerate(onderwerp_labels)}
    beleving_to_idx = {label: idx for idx, label in enumerate(beleving_labels)}

    # Encode labels to multi-hot vectors
    n_samples = len(ds)
    n_onderwerp = len(onderwerp_labels)
    n_beleving = len(beleving_labels)

    # Preallocate arrays (faster than appending)
    texts = []
    onderwerp_encoded = np.zeros((n_samples, n_onderwerp), dtype=np.float32)
    beleving_encoded = np.zeros((n_samples, n_beleving), dtype=np.float32)

    # Fill arrays
    for i, sample in enumerate(ds):
        texts.append(sample['text'])

        # Encode onderwerp labels (multi-hot)
        for label in sample['onderwerp_labels']:
            idx = onderwerp_to_idx[label]
            onderwerp_encoded[i, idx] = 1.0

        # Encode beleving labels (multi-hot)
        for label in sample['beleving_labels']:
            idx = beleving_to_idx[label]
            beleving_encoded[i, idx] = 1.0

    print(f"Encoded {n_samples} samples")
    print(f"  onderwerp shape: {onderwerp_encoded.shape}")
    print(f"  beleving shape: {beleving_encoded.shape}")

    return texts, onderwerp_encoded, beleving_encoded, onderwerp_labels, beleving_labels
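
# Optional helper, not part of the original loader: multi-label sets like this
# are usually imbalanced, and a common remedy when training with a per-label
# binary cross-entropy loss is to weight positives by the negative/positive
# count ratio. The function name and the exact weighting scheme below are
# illustrative assumptions, included as a hedged sketch rather than a fixed API.
def compute_pos_weights(encoded, eps=1.0):
    """
    Compute per-label positive-class weights from a multi-hot matrix.

    Args:
        encoded: numpy array [n_samples, n_labels] with 0/1 entries
        eps: smoothing constant to avoid division by zero for unseen labels

    Returns:
        numpy array [n_labels] - neg_count / pos_count per label, usable as
        e.g. torch.nn.BCEWithLogitsLoss(pos_weight=...)
    """
    pos = encoded.sum(axis=0)        # positive count per label
    neg = encoded.shape[0] - pos     # negative count per label
    return ((neg + eps) / (pos + eps)).astype(np.float32)
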

def print_sample_info(texts, onderwerp_encoded, beleving_encoded, onderwerp_labels, beleving_labels, sample_idx=0):
    """
    Print information about a specific sample (useful for debugging).

    Args:
        All outputs from load_rd_wim_dataset()
        sample_idx: Which sample to print (default: 0)
    """
    print(f"\n{'='*60}")
    print(f"SAMPLE {sample_idx}")
    print(f"{'='*60}")
    print(f"Text: {texts[sample_idx][:200]}...")
    print()

    # Get active onderwerp labels
    onderwerp_active = [onderwerp_labels[i] for i, val in enumerate(onderwerp_encoded[sample_idx]) if val == 1]
    print(f"Onderwerp labels ({len(onderwerp_active)}):")
    for label in onderwerp_active:
        print(f"  - {label}")
    print()

    # Get active beleving labels
    beleving_active = [beleving_labels[i] for i, val in enumerate(beleving_encoded[sample_idx]) if val == 1]
    print(f"Beleving labels ({len(beleving_active)}):")
    for label in beleving_active:
        print(f"  - {label}")
    print(f"{'='*60}\n")


if __name__ == "__main__":
    # Test the loader
    print("Testing UWV dataset loader...\n")

    # Load small subset for testing
    texts, onderwerp, beleving, onderwerp_names, beleving_names = load_rd_wim_dataset(max_samples=10)

    # Print first sample
    print_sample_info(texts, onderwerp, beleving, onderwerp_names, beleving_names, sample_idx=0)

    # Print statistics
    print("\nDataset Statistics:")
    print(f"  Total samples: {len(texts)}")
    print(f"  Avg onderwerp labels per sample: {onderwerp.sum(axis=1).mean():.2f}")
    print(f"  Avg beleving labels per sample: {beleving.sum(axis=1).mean():.2f}")
    print(f"  Text length range: {min(len(t) for t in texts)} - {max(len(t) for t in texts)} chars")
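
    # Hedged usage sketch: a quick end-to-end sanity check of the encoded
    # outputs. Assumes scikit-learn is installed; TF-IDF plus one-vs-rest
    # logistic regression is an illustrative baseline choice, not the
    # project's actual model.
    try:
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.linear_model import LogisticRegression
        from sklearn.multiclass import OneVsRestClassifier

        # In a tiny subset, only labels that occur with both classes
        # (at least one positive and one negative) can be fit per-label.
        counts = onderwerp.sum(axis=0)
        usable = (counts > 0) & (counts < len(texts))
        if usable.any():
            X = TfidfVectorizer(max_features=5000).fit_transform(texts)
            clf = OneVsRestClassifier(LogisticRegression(max_iter=1000))
            clf.fit(X, onderwerp[:, usable])
            train_acc = (clf.predict(X) == onderwerp[:, usable]).mean()
            print(f"\nBaseline sanity check: {int(usable.sum())} usable onderwerp labels, "
                  f"train accuracy {train_acc:.2f}")
    except ImportError:
        print("\nscikit-learn not installed; skipping baseline sanity check")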