File size: 7,328 Bytes
bdd5464
 
 
 
 
 
 
23bf736
bdd5464
 
13d4fa0
bdd5464
23bf736
 
 
 
 
 
bdd5464
23bf736
 
bdd5464
 
13d4fa0
bdd5464
23bf736
bdd5464
 
 
13d4fa0
 
 
 
bdd5464
 
23bf736
bdd5464
23bf736
 
13d4fa0
 
23bf736
 
 
13d4fa0
23bf736
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13d4fa0
23bf736
7a24f2a
 
23bf736
 
7a24f2a
23bf736
 
 
7a24f2a
bdd5464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#!/usr/bin/env python3
"""
UWV/wim-synthetic-data-rd dataset loader for multi-label classification.
Carmack-style: minimal abstraction, direct data flow, fast operations.
"""

import numpy as np
from datasets import load_dataset, concatenate_datasets


def _strip_no_label_sentinel(labels, sentinel='No subtopic found'):
    """
    Return a copy of *labels* with every *sentinel* entry removed.

    The source data uses the placeholder 'No subtopic found' to mean "no
    label". It is dropped wherever it appears, not only when it is the sole
    entry, so it can never surface as a real class. A missing (None or
    empty) label list yields [].
    """
    if not labels:
        return []
    return [label for label in labels if label != sentinel]


def _multi_hot(label_lists, vocabulary):
    """
    Encode per-sample label lists as a float32 multi-hot matrix.

    Args:
        label_lists: Sequence of per-sample label lists
        vocabulary: Ordered label names; column j corresponds to vocabulary[j]

    Returns:
        numpy array [len(label_lists), len(vocabulary)] with 1.0 where a label is present

    Raises:
        KeyError: if a sample contains a label missing from *vocabulary*
    """
    column_of = {label: idx for idx, label in enumerate(vocabulary)}
    encoded = np.zeros((len(label_lists), len(vocabulary)), dtype=np.float32)
    for row, labels in enumerate(label_lists):
        for label in labels:
            encoded[row, column_of[label]] = 1.0
    return encoded


def load_rd_wim_dataset(max_samples=None, split='train', filter_calamity=True):
    """
    Load combined UWV datasets and encode multi-labels.

    Combines two datasets:
    - UWV/wim-synthetic-data-rd: Original RD dataset
    - UWV/wim_synthetic_data_for_testing_split_labels: Validated testing dataset

    Dataset contains Dutch municipal complaint conversations with two types of labels:
    - onderwerp: What the message is about
    - beleving: How the citizen experienced the interaction

    The combined dataset is shuffled with a fixed seed (42) before truncation.
    The placeholder label 'No subtopic found' is removed from every label
    list, so samples that only carried that placeholder get an all-zero row.

    Args:
        max_samples: Limit number of samples (None = all samples)
        split: Dataset split to load (default: 'train')
        filter_calamity: If True, exclude samples with is_calamity=True from RD dataset (default: True)

    Returns:
        texts: List of conversation strings
        onderwerp_encoded: numpy array [n_samples, n_onderwerp] - multi-hot encoded topics
        beleving_encoded: numpy array [n_samples, n_beleving] - multi-hot encoded experiences
        onderwerp_labels: List of onderwerp label names (sorted alphabetically)
        beleving_labels: List of beleving label names (sorted alphabetically)
    """
    # Load RD dataset
    print(f"Loading UWV/wim-synthetic-data-rd dataset (split={split})...")
    ds_rd = load_dataset('UWV/wim-synthetic-data-rd', split=split)

    # Filter out calamity samples if requested
    if filter_calamity:
        original_len = len(ds_rd)
        ds_rd = ds_rd.filter(lambda x: not x['is_calamity'])
        filtered_len = len(ds_rd)
        print(f"Filtered out {original_len - filtered_len} calamity samples ({filtered_len} remaining)")

    # Keep only essential columns so the two datasets can be concatenated
    ds_rd = ds_rd.select_columns(['text', 'onderwerp_labels', 'beleving_labels'])
    print(f"RD dataset: {len(ds_rd)} samples")

    # Load testing dataset
    print(f"Loading UWV/wim_synthetic_data_for_testing_split_labels dataset (split={split})...")
    ds_test = load_dataset('UWV/wim_synthetic_data_for_testing_split_labels', split=split)

    # Rename columns to match RD dataset structure
    ds_test = ds_test.map(lambda x: {
        'text': x['Synthetic Text'],
        'onderwerp_labels': x['validated_onderwerp_labels'],
        'beleving_labels': x['validated_beleving_labels']
    }, remove_columns=ds_test.column_names)
    print(f"Testing dataset: {len(ds_test)} samples")

    # Concatenate datasets
    ds = concatenate_datasets([ds_rd, ds_test])
    print(f"Combined dataset: {len(ds)} samples")

    # Shuffle with fixed seed for reproducibility
    ds = ds.shuffle(seed=42)
    print("Shuffled combined dataset")

    # Pull each column once instead of iterating the dataset row-by-row
    # several times.
    texts = list(ds['text'])
    onderwerp_lists = [_strip_no_label_sentinel(labels) for labels in ds['onderwerp_labels']]
    beleving_lists = [_strip_no_label_sentinel(labels) for labels in ds['beleving_labels']]

    no_onderwerp_count = sum(1 for labels in onderwerp_lists if not labels)
    no_beleving_count = sum(1 for labels in beleving_lists if not labels)
    print(f"Removed 'No subtopic found' placeholders; samples without labels: "
          f"{no_onderwerp_count} onderwerp, {no_beleving_count} beleving")

    # Limit samples if requested (negative values yield zero samples, as before)
    if max_samples is not None:
        keep = max(0, min(max_samples, len(texts)))
        texts = texts[:keep]
        onderwerp_lists = onderwerp_lists[:keep]
        beleving_lists = beleving_lists[:keep]

    print(f"Loaded {len(texts)} samples")

    # Label vocabulary comes from the (possibly truncated) samples that are
    # returned. It is sorted alphabetically for consistent indexing across runs.
    onderwerp_labels = sorted({label for labels in onderwerp_lists for label in labels})
    beleving_labels = sorted({label for labels in beleving_lists for label in labels})

    print(f"Found {len(onderwerp_labels)} unique onderwerp labels")
    print(f"Found {len(beleving_labels)} unique beleving labels")

    onderwerp_encoded = _multi_hot(onderwerp_lists, onderwerp_labels)
    beleving_encoded = _multi_hot(beleving_lists, beleving_labels)

    print(f"Encoded {len(texts)} samples")
    print(f"  onderwerp shape: {onderwerp_encoded.shape}")
    print(f"  beleving shape: {beleving_encoded.shape}")

    return texts, onderwerp_encoded, beleving_encoded, onderwerp_labels, beleving_labels


def print_sample_info(texts, onderwerp_encoded, beleving_encoded,
                     onderwerp_labels, beleving_labels, sample_idx=0):
    """
    Print the text preview and active labels of one sample (useful for debugging).

    Args:
        texts, onderwerp_encoded, beleving_encoded, onderwerp_labels, beleving_labels:
            The five values returned by load_rd_wim_dataset(), in that order
        sample_idx: Which sample to print (default: 0)
    """
    separator = '=' * 60

    def _active(encoded, names):
        # Label names whose multi-hot entry for this sample is set to 1
        return [names[j] for j, val in enumerate(encoded[sample_idx]) if val == 1]

    print(f"\n{separator}")
    print(f"SAMPLE {sample_idx}")
    print(separator)

    text = texts[sample_idx]
    # Only add an ellipsis when the 200-char preview actually truncates the text
    preview = text[:200] + ('...' if len(text) > 200 else '')
    print(f"Text: {preview}")
    print()

    onderwerp_active = _active(onderwerp_encoded, onderwerp_labels)
    print(f"Onderwerp labels ({len(onderwerp_active)}):")
    for label in onderwerp_active:
        print(f"  - {label}")
    print()

    beleving_active = _active(beleving_encoded, beleving_labels)
    print(f"Beleving labels ({len(beleving_active)}):")
    for label in beleving_active:
        print(f"  - {label}")
    print(f"{separator}\n")


if __name__ == "__main__":
    # Smoke test: load a small slice, show one sample, then summary statistics.
    print("Testing UWV dataset loader...\n")

    (texts, onderwerp, beleving,
     onderwerp_names, beleving_names) = load_rd_wim_dataset(max_samples=10)

    print_sample_info(texts, onderwerp, beleving,
                      onderwerp_names, beleving_names, sample_idx=0)

    text_lengths = [len(t) for t in texts]

    print("\nDataset Statistics:")
    print(f"  Total samples: {len(texts)}")
    avg_onderwerp = onderwerp.sum(axis=1).mean()
    print(f"  Avg onderwerp labels per sample: {avg_onderwerp:.2f}")
    avg_beleving = beleving.sum(axis=1).mean()
    print(f"  Avg beleving labels per sample: {avg_beleving:.2f}")
    print(f"  Text length range: {min(text_lengths)} - {max(text_lengths)} chars")