evalstate HF Staff commited on
Commit
838a6a1
Β·
verified Β·
1 Parent(s): a5c1053

Upload validate_dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. validate_dataset.py +175 -0
validate_dataset.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # dependencies = [
4
+ # "datasets>=2.14.0",
5
+ # ]
6
+ # ///
7
+ """
8
+ Validate dataset format for TRL training.
9
+
10
+ Usage:
11
+ python validate_dataset.py <dataset_name> <method>
12
+
13
+ Examples:
14
+ python validate_dataset.py trl-lib/Capybara sft
15
+ python validate_dataset.py Anthropic/hh-rlhf dpo
16
+ """
17
+
18
+ import sys
19
+ from datasets import load_dataset
20
+
21
+ def validate_sft_dataset(dataset):
22
+ """Validate SFT dataset format."""
23
+ print("πŸ” Validating SFT dataset...")
24
+
25
+ # Check for common fields
26
+ columns = dataset.column_names
27
+ print(f"πŸ“‹ Columns: {columns}")
28
+
29
+ has_messages = "messages" in columns
30
+ has_text = "text" in columns
31
+
32
+ if not (has_messages or has_text):
33
+ print("❌ Dataset must have 'messages' or 'text' field")
34
+ return False
35
+
36
+ # Check first example
37
+ example = dataset[0]
38
+
39
+ if has_messages:
40
+ messages = example["messages"]
41
+ if not isinstance(messages, list):
42
+ print("❌ 'messages' field must be a list")
43
+ return False
44
+
45
+ if len(messages) == 0:
46
+ print("❌ 'messages' field is empty")
47
+ return False
48
+
49
+ # Check message format
50
+ msg = messages[0]
51
+ if not isinstance(msg, dict):
52
+ print("❌ Messages must be dictionaries")
53
+ return False
54
+
55
+ if "role" not in msg or "content" not in msg:
56
+ print("❌ Messages must have 'role' and 'content' keys")
57
+ return False
58
+
59
+ print("βœ… Messages format valid")
60
+ print(f" First message: {msg['role']}: {msg['content'][:50]}...")
61
+
62
+ if has_text:
63
+ text = example["text"]
64
+ if not isinstance(text, str):
65
+ print("❌ 'text' field must be a string")
66
+ return False
67
+
68
+ if len(text) == 0:
69
+ print("❌ 'text' field is empty")
70
+ return False
71
+
72
+ print("βœ… Text format valid")
73
+ print(f" First text: {text[:100]}...")
74
+
75
+ return True
76
+
77
+ def validate_dpo_dataset(dataset):
78
+ """Validate DPO dataset format."""
79
+ print("πŸ” Validating DPO dataset...")
80
+
81
+ columns = dataset.column_names
82
+ print(f"πŸ“‹ Columns: {columns}")
83
+
84
+ required = ["prompt", "chosen", "rejected"]
85
+ missing = [col for col in required if col not in columns]
86
+
87
+ if missing:
88
+ print(f"❌ Missing required fields: {missing}")
89
+ return False
90
+
91
+ # Check first example
92
+ example = dataset[0]
93
+
94
+ for field in required:
95
+ value = example[field]
96
+ if isinstance(value, str):
97
+ if len(value) == 0:
98
+ print(f"❌ '{field}' field is empty")
99
+ return False
100
+ print(f"βœ… '{field}' format valid (string)")
101
+ elif isinstance(value, list):
102
+ if len(value) == 0:
103
+ print(f"❌ '{field}' field is empty")
104
+ return False
105
+ print(f"βœ… '{field}' format valid (list of messages)")
106
+ else:
107
+ print(f"❌ '{field}' must be string or list")
108
+ return False
109
+
110
+ return True
111
+
112
+ def validate_kto_dataset(dataset):
113
+ """Validate KTO dataset format."""
114
+ print("πŸ” Validating KTO dataset...")
115
+
116
+ columns = dataset.column_names
117
+ print(f"πŸ“‹ Columns: {columns}")
118
+
119
+ required = ["prompt", "completion", "label"]
120
+ missing = [col for col in required if col not in columns]
121
+
122
+ if missing:
123
+ print(f"❌ Missing required fields: {missing}")
124
+ return False
125
+
126
+ # Check first example
127
+ example = dataset[0]
128
+
129
+ if not isinstance(example["label"], bool):
130
+ print("❌ 'label' field must be boolean")
131
+ return False
132
+
133
+ print("βœ… KTO format valid")
134
+ return True
135
+
136
+ def main():
137
+ if len(sys.argv) != 3:
138
+ print("Usage: python validate_dataset.py <dataset_name> <method>")
139
+ print("Methods: sft, dpo, kto")
140
+ sys.exit(1)
141
+
142
+ dataset_name = sys.argv[1]
143
+ method = sys.argv[2].lower()
144
+
145
+ print(f"πŸ“¦ Loading dataset: {dataset_name}")
146
+ try:
147
+ dataset = load_dataset(dataset_name, split="train")
148
+ print(f"βœ… Dataset loaded: {len(dataset)} examples")
149
+ except Exception as e:
150
+ print(f"❌ Failed to load dataset: {e}")
151
+ sys.exit(1)
152
+
153
+ validators = {
154
+ "sft": validate_sft_dataset,
155
+ "dpo": validate_dpo_dataset,
156
+ "kto": validate_kto_dataset,
157
+ }
158
+
159
+ if method not in validators:
160
+ print(f"❌ Unknown method: {method}")
161
+ print(f"Supported methods: {list(validators.keys())}")
162
+ sys.exit(1)
163
+
164
+ validator = validators[method]
165
+ valid = validator(dataset)
166
+
167
+ if valid:
168
+ print(f"\nβœ… Dataset is valid for {method.upper()} training")
169
+ sys.exit(0)
170
+ else:
171
+ print(f"\n❌ Dataset is NOT valid for {method.upper()} training")
172
+ sys.exit(1)
173
+
174
+ if __name__ == "__main__":
175
+ main()