#!/usr/bin/env python3
"""
GPT-OSS Training Script

Specialized training script for OpenAI's GPT-OSS models.
Based on the GPT-OSS fine-tuning tutorial.
"""
import os
import sys
import argparse

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset


def load_gpt_oss_model_and_tokenizer(config):
    """Load the GPT-OSS model and tokenizer with the proper configuration."""
    print("Loading GPT-OSS tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    print("Loading GPT-OSS model with quantization...")
    from transformers import BitsAndBytesConfig

    # Set up the quantization config based on the training config
    if config.quantization_config and config.quantization_config.get("load_in_4bit"):
        # Use BitsAndBytesConfig for 4-bit NF4 quantization (memory optimized)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif config.quantization_config and config.quantization_config.get("dequantize"):
        # Try to use Mxfp4Config if available (as per the tutorial)
        try:
            from transformers import Mxfp4Config
            quantization_config = Mxfp4Config(dequantize=True)
        except ImportError:
            # Fall back to no quantization if Mxfp4Config is not available
            print("Warning: Mxfp4Config not available, using no quantization")
            quantization_config = None
    else:
        # No quantization
        quantization_config = None

    # Model kwargs as per the tutorial
    model_kwargs = {
        "attn_implementation": "eager",
        "torch_dtype": torch.bfloat16,
        "use_cache": False,
        "device_map": "auto",
    }
    # Only pass quantization_config when one was actually built
    if quantization_config is not None:
        model_kwargs["quantization_config"] = quantization_config

    model = AutoModelForCausalLM.from_pretrained(config.model_name, **model_kwargs)
    return model, tokenizer
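# The loader above expects `config.quantization_config` to be a plain dict (or None).
# Two shapes are recognized, matching the branches above (values illustrative):
#   {"load_in_4bit": True}  -> BitsAndBytesConfig 4-bit NF4 quantization
#   {"dequantize": True}    -> Mxfp4Config, dequantizing MXFP4 weights to bf16
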
def setup_lora_for_gpt_oss(model, config):
    """Setup LoRA for a GPT-OSS model."""
    print("Setting up LoRA for GPT-OSS...")
    lora_cfg = config.lora_config or {}
    # Default expert parameters to adapt: the MoE gate_up/down projections in
    # transformer layers 7, 15 and 23, as in the tutorial.
    default_target_parameters = [
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj",
    ]
    # LoRA configuration as per the tutorial
    lora_config = LoraConfig(
        r=lora_cfg.get("r", 8),
        lora_alpha=lora_cfg.get("lora_alpha", 16),
        target_modules=lora_cfg.get("target_modules", "all-linear"),
        target_parameters=lora_cfg.get("target_parameters", default_target_parameters),
    )
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()
    return peft_model
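# Illustrative `config.lora_config` dict (keys mirror the .get() lookups above):
#   {"r": 16, "lora_alpha": 32, "target_modules": "all-linear"}
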
def load_dataset_from_config(config):
    """Load the dataset based on the configuration."""
    dataset_name = getattr(config, 'dataset_name', 'HuggingFaceH4/Multilingual-Thinking')
    dataset_split = getattr(config, 'dataset_split', 'train')
    dataset_config = getattr(config, 'dataset_config', None)

    print(f"Loading dataset: {dataset_name}")
    print(f"Dataset split: {dataset_split}")
    if dataset_config:
        print(f"Dataset config: {dataset_config}")

    # Load the dataset
    if dataset_config:
        dataset = load_dataset(dataset_name, dataset_config, split=dataset_split)
    else:
        dataset = load_dataset(dataset_name, split=dataset_split)
    print(f"Original dataset size: {len(dataset)} examples")

    # Apply filtering based on the configuration
    dataset = apply_dataset_filtering(dataset, config)
    # Apply processing based on the dataset format
    dataset = process_dataset_format(dataset, config)

    print(f"Final dataset size: {len(dataset)} examples")
    return dataset


def apply_dataset_filtering(dataset, config):
    """Apply filtering based on the configuration."""
    # Filter out flagged entries if requested
    if getattr(config, 'filter_bad_entries', False):
        bad_entry_field = getattr(config, 'bad_entry_field', 'bad_entry')
        bad_prompt_field = getattr(config, 'bad_prompt_field', 'bad_prompt_detected')
        bad_response_field = getattr(config, 'bad_response_field', 'bad_response_detected')
        original_size = len(dataset)

        # Filter out bad entries
        if bad_entry_field in dataset.column_names:
            dataset = dataset.filter(lambda x: not x.get(bad_entry_field, False))
            print(f"Filtered {original_size - len(dataset)} bad entries")
        # Filter out bad prompts
        if bad_prompt_field in dataset.column_names:
            dataset = dataset.filter(lambda x: not x.get(bad_prompt_field, False))
            print(f"Filtered bad prompts, remaining: {len(dataset)} examples")
        # Filter out bad responses
        if bad_response_field in dataset.column_names:
            dataset = dataset.filter(lambda x: not x.get(bad_response_field, False))
            print(f"Filtered bad responses, remaining: {len(dataset)} examples")

    # Apply length filtering (character counts over input + target)
    min_length = getattr(config, 'min_length', 10)
    max_length = getattr(config, 'max_length', None)
    input_field = getattr(config, 'input_field', 'prompt')
    target_field = getattr(config, 'target_field', 'accepted_completion')

    if min_length > 0 or max_length:
        def length_filter(example):
            total_len = len(example.get(input_field, '')) + len(example.get(target_field, ''))
            if total_len < min_length:
                return False
            if max_length and total_len > max_length:
                return False
            return True

        original_size = len(dataset)
        dataset = dataset.filter(length_filter)
        print(f"Length filtering: {original_size} -> {len(dataset)} examples")

    # Optionally subsample the dataset
    max_samples = getattr(config, 'max_samples', None)
    if max_samples and len(dataset) > max_samples:
        dataset = dataset.shuffle(seed=42).select(range(max_samples))
        print(f"Sampled {max_samples} examples from dataset")

    return dataset


def format_gpt_oss_harmony(prompt, completion, add_eos_token=True):
    """
    Format data for the GPT-OSS Harmony format, following the exact template structure.
    Based on: https://huggingface.co/openai/gpt-oss-20b/raw/main/chat_template.jinja
    """
    # GPT-OSS Harmony message structure (exact template compliance):
    #   User message:      <|start|>user<|message|>content<|end|>
    #   Assistant message: <|start|>assistant<|channel|>final<|message|>content<|end|>     (inference)
    #   Assistant message: <|start|>assistant<|channel|>final<|message|>content<|return|>  (training)
    harmony_text = f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>{completion}"
    if add_eos_token:
        # <|return|> marks the end of generation in training examples, per the template
        harmony_text += "<|return|>"
    else:
        # <|end|> is used for inference
        harmony_text += "<|end|>"
    return harmony_text
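# Example output for a single pair (illustrative prompt/completion):
#   <|start|>user<|message|>Bonjour !<|end|><|start|>assistant<|channel|>final<|message|>Hello!<|return|>
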
def process_dataset_format(dataset, config):
    """Process the dataset based on the format configuration, with exact GPT-OSS Harmony compliance."""
    dataset_format = getattr(config, 'dataset_format', 'openhermes_fr')
    input_field = getattr(config, 'input_field', 'prompt')
    target_field = getattr(config, 'target_field', 'accepted_completion')
    concatenate_fields = getattr(config, 'concatenate_fields', True)
    field_separator = getattr(config, 'field_separator', '\n\n### Response:\n')
    add_eos_token = getattr(config, 'add_eos_token', True)
    use_harmony_format = getattr(config, 'use_harmony_format', True)

    print(f"Processing dataset format: {dataset_format}")
    print(f"Input field: {input_field}, Target field: {target_field}")
    print(f"GPT-OSS Harmony Format: {'Enabled' if use_harmony_format else 'Disabled'}")

    if dataset_format == "openhermes_fr":
        # OpenHermes-FR format: prompt + accepted_completion
        def format_openhermes_fr(example):
            prompt = example.get(input_field, '')
            completion = example.get(target_field, '')
            if concatenate_fields:
                if use_harmony_format:
                    # Use the exact GPT-OSS Harmony format from the template
                    text = format_gpt_oss_harmony(prompt, completion, add_eos_token)
                else:
                    # Fall back to a standard format with a separator
                    text = prompt + field_separator + completion
                    if add_eos_token:
                        text += "</s>"
                return {"text": text}
            else:
                # Keep fields separate for more advanced training setups
                return {"input": prompt, "output": completion}

        dataset = dataset.map(format_openhermes_fr, remove_columns=dataset.column_names)

    elif dataset_format == "messages":
        # Messages format (like HuggingFaceH4/Multilingual-Thinking).
        # Note: for this format, `input_field` should name the messages column
        # (e.g. "messages"), since its default is 'prompt'.
        def format_messages(example):
            messages = example.get(input_field, [])
            if use_harmony_format and len(messages) >= 2:
                # Extract the user and assistant messages for the Harmony format
                user_message = ""
                assistant_message = ""
                for message in messages:
                    role = message.get("role", "")
                    content = message.get("content", "")
                    if role == "user":
                        user_message = content
                    elif role == "assistant":
                        assistant_message = content
                if user_message and assistant_message:
                    # Use the GPT-OSS Harmony format
                    text = format_gpt_oss_harmony(user_message, assistant_message, add_eos_token)
                else:
                    # Fall back to simple concatenation
                    text = ""
                    for message in messages:
                        text += f"{message.get('role', '')}: {message.get('content', '')}\n"
                    if add_eos_token:
                        text += "</s>"
            else:
                # Standard format: convert messages to plain text
                text = ""
                for message in messages:
                    text += f"{message.get('role', '')}: {message.get('content', '')}\n"
                if add_eos_token:
                    text += "</s>"
            return {"text": text}

        dataset = dataset.map(format_messages, remove_columns=dataset.column_names)

    elif dataset_format == "text":
        # Plain text format
        def format_text(example):
            text = example.get(input_field, '')
            if add_eos_token:
                text += "</s>"
            return {"text": text}

        dataset = dataset.map(format_text, remove_columns=dataset.column_names)

    elif dataset_format == "custom":
        # Custom format: the user handles processing in their own config
        print("Using custom dataset format - no automatic processing")

    return dataset
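# Illustrative format-related config attributes (defaults mirror the getattr calls above):
#   dataset_format     = "openhermes_fr" | "messages" | "text" | "custom"
#   input_field        = "prompt"        (use "messages" for the messages format)
#   target_field       = "accepted_completion"
#   use_harmony_format = True
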
def setup_trackio_tracking(config):
    """Setup Trackio tracking if enabled."""
    if not config.enable_tracking or not config.trackio_url:
        print("Trackio tracking disabled or URL not provided")
        return None

    print(f"Setting up Trackio tracking: {config.trackio_url}")
    # Import the project's TrackioAPIClient from the sibling trackio_tonic directory
    # (os and sys are already imported at module level)
    sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'trackio_tonic'))
    from trackio_api_client import TrackioAPIClient

    # Initialize the Trackio client using the correct API
    trackio_client = TrackioAPIClient(
        space_id=config.trackio_url,
        hf_token=config.trackio_token,
    )
    return trackio_client
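# Note: the client returned here is only instantiated; per-step metric reporting
# is wired up separately via `report_to="trackio"` in create_sft_config() below.
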
def create_sft_config(config, output_dir):
    """Create an enhanced SFTConfig for GPT-OSS training."""
    print("Creating enhanced SFT configuration...")

    # Training duration. transformers expects -1 for "no max_steps" and 0 for
    # "no warmup_steps", so coerce missing values rather than passing None.
    num_train_epochs = getattr(config, 'num_train_epochs', 1.0)
    max_steps = getattr(config, 'max_steps', None) or -1
    warmup_ratio = getattr(config, 'warmup_ratio', 0.03)
    warmup_steps = getattr(config, 'warmup_steps', None) or 0

    # Learning rate configuration
    learning_rate = config.learning_rate
    lr_scheduler_type = getattr(config, 'scheduler', 'cosine_with_min_lr')
    lr_scheduler_kwargs = getattr(config, 'lr_scheduler_kwargs', {"min_lr_rate": 0.1})

    # Batch configuration
    per_device_train_batch_size = config.batch_size
    per_device_eval_batch_size = getattr(config, 'eval_batch_size', config.batch_size)
    gradient_accumulation_steps = config.gradient_accumulation_steps

    # Evaluation and logging
    eval_strategy = getattr(config, 'eval_strategy', 'steps')
    eval_steps = getattr(config, 'eval_steps', 100)
    logging_steps = getattr(config, 'logging_steps', 10)

    # Saving configuration
    save_strategy = getattr(config, 'save_strategy', 'steps')
    save_steps = getattr(config, 'save_steps', 500)
    save_total_limit = getattr(config, 'save_total_limit', 3)

    # Mixed precision
    fp16 = getattr(config, 'fp16', False)
    bf16 = getattr(config, 'bf16', True)

    # Regularization
    weight_decay = getattr(config, 'weight_decay', 0.01)
    max_grad_norm = getattr(config, 'max_grad_norm', 1.0)

    # Hugging Face Hub integration
    push_to_hub = getattr(config, 'push_to_hub', False)

    print(f"  • Epochs: {num_train_epochs}")
    print(f"  • Learning rate: {learning_rate}")
    print(f"  • Batch size: {per_device_train_batch_size}")
    print(f"  • Gradient accumulation: {gradient_accumulation_steps}")
    print(f"  • Effective batch size: {per_device_train_batch_size * gradient_accumulation_steps}")

    sft_config = SFTConfig(
        # Training duration
        num_train_epochs=num_train_epochs,
        max_steps=max_steps,
        # Learning rate
        learning_rate=learning_rate,
        lr_scheduler_type=lr_scheduler_type,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        warmup_ratio=warmup_ratio,
        warmup_steps=warmup_steps,
        # Batch configuration
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        # Model configuration
        max_seq_length=config.max_seq_length,
        gradient_checkpointing=getattr(config, 'use_gradient_checkpointing', True),
        # Mixed precision
        fp16=fp16,
        bf16=bf16,
        # Regularization
        weight_decay=weight_decay,
        max_grad_norm=max_grad_norm,
        # Evaluation ("evaluation_strategy" was renamed "eval_strategy" in recent transformers)
        eval_strategy=eval_strategy,
        eval_steps=eval_steps,
        # Logging
        logging_steps=logging_steps,
        # Saving
        save_strategy=save_strategy,
        save_steps=save_steps,
        save_total_limit=save_total_limit,
        # Output
        output_dir=output_dir,
        # Data loading
        dataloader_num_workers=getattr(config, 'dataloader_num_workers', 4),
        dataloader_pin_memory=getattr(config, 'dataloader_pin_memory', True),
        # Performance
        group_by_length=getattr(config, 'group_by_length', True),
        remove_unused_columns=getattr(config, 'remove_unused_columns', True),
        # Hugging Face Hub
        push_to_hub=push_to_hub,
        # Monitoring ("none" rather than None, which transformers treats as "all")
        report_to="trackio" if getattr(config, 'enable_tracking', False) else "none",
    )
    return sft_config
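# Note: config.learning_rate, config.batch_size, config.gradient_accumulation_steps
# and config.max_seq_length are read directly above, so they are required attributes
# of the config object; every other field falls back to the defaults shown.
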
def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer_type="sft"):
    """Main training function for GPT-OSS."""
    print("=== GPT-OSS Training Pipeline ===")
    print(f"Config: {config_path}")
    print(f"Experiment: {experiment_name}")
    print(f"Output: {output_dir}")
    print(f"Trackio: {trackio_url}")
    print(f"Trainer: {trainer_type}")

    # Load the configuration module from the given path
    if os.path.exists(config_path):
        import importlib.util
        spec = importlib.util.spec_from_file_location("config_module", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)

        if hasattr(config_module, 'config'):
            config = config_module.config
        else:
            # Otherwise look for a GPT-OSS config object in the module
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if hasattr(attr, 'model_name') and ('gpt_oss' in attr.model_name.lower() or 'GPTOSS' in attr_name):
                    config = attr
                    break
            else:
                raise ValueError(f"No GPT-OSS configuration found in {config_path}")
    else:
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    # Update the config with runtime parameters
    config.experiment_name = experiment_name
    config.trackio_url = trackio_url
    config.trainer_type = trainer_type

    # Load the model and tokenizer
    model, tokenizer = load_gpt_oss_model_and_tokenizer(config)
    # Setup LoRA
    peft_model = setup_lora_for_gpt_oss(model, config)
    # Load the dataset
    dataset = load_dataset_from_config(config)
    # Setup Trackio tracking
    trackio_client = setup_trackio_tracking(config)
    # Create the SFT configuration
    sft_config = create_sft_config(config, output_dir)

    # Create the trainer
    print("Creating SFT trainer...")
    trainer = SFTTrainer(
        model=peft_model,
        args=sft_config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    # Start training
    print("Starting GPT-OSS training...")
    trainer.train()

    # Save the model
    print("Saving trained model...")
    trainer.save_model(output_dir)

    # Push to the Hub if enabled
    if sft_config.push_to_hub:
        print("Pushing model to Hugging Face Hub...")
        trainer.push_to_hub(dataset_name="HuggingFaceH4/Multilingual-Thinking")

    print("GPT-OSS training completed successfully!")
    return trainer
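# Loading the trained adapter afterwards (an illustrative sketch, not part of this script):
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(config.model_name, torch_dtype=torch.bfloat16)
#   model = PeftModel.from_pretrained(base, output_dir)
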
def main():
    parser = argparse.ArgumentParser(description="GPT-OSS Training Script")
    parser.add_argument("--config", required=True, help="Path to configuration file")
    parser.add_argument("--experiment-name", required=True, help="Experiment name")
    parser.add_argument("--output-dir", required=True, help="Output directory for checkpoints")
    parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
    parser.add_argument("--trainer-type", default="sft", choices=["sft", "dpo"], help="Trainer type")
    args = parser.parse_args()

    # Validate arguments
    if not os.path.exists(args.config):
        print(f"Error: Configuration file not found: {args.config}")
        sys.exit(1)

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    try:
        train_gpt_oss(
            config_path=args.config,
            experiment_name=args.experiment_name,
            output_dir=args.output_dir,
            trackio_url=args.trackio_url,
            trainer_type=args.trainer_type,
        )
    except Exception as e:
        print(f"Error during training: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()