#!/usr/bin/env python3
"""
GPT-OSS Training Script
Specialized training script for OpenAI's GPT-OSS models
Based on the GPT-OSS fine-tuning tutorial
"""
import os
import sys
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset


def load_gpt_oss_model_and_tokenizer(config):
"""Load GPT-OSS model and tokenizer with proper configuration"""
print("Loading GPT-OSS tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
print("Loading GPT-OSS model with quantization...")
# Import quantization config
from transformers import BitsAndBytesConfig
# Set up quantization config based on config
if config.quantization_config and config.quantization_config.get("load_in_4bit"):
# Use BitsAndBytesConfig for 4-bit quantization (memory optimized)
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
elif config.quantization_config and config.quantization_config.get("dequantize"):
# Try to use Mxfp4Config if available (as per tutorial)
try:
from transformers import Mxfp4Config
quantization_config = Mxfp4Config(dequantize=True)
except ImportError:
# Fallback to no quantization if Mxfp4Config not available
print("Warning: Mxfp4Config not available, using no quantization")
quantization_config = None
else:
# No quantization
quantization_config = None
# Model kwargs as per tutorial
model_kwargs = {
"attn_implementation": "eager",
"torch_dtype": torch.bfloat16,
"use_cache": False,
"device_map": "auto",
}
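    # Notes on the kwargs above: "eager" attention follows the tutorial's settings
    # for GPT-OSS, and use_cache=False because the KV cache is unused during
    # training and is incompatible with the gradient checkpointing enabled in
    # create_sft_config() below.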
# Only add quantization_config if it's not None
if quantization_config is not None:
model_kwargs["quantization_config"] = quantization_config
model = AutoModelForCausalLM.from_pretrained(config.model_name, **model_kwargs)
return model, tokenizer


def setup_lora_for_gpt_oss(model, config):
"""Setup LoRA for GPT-OSS model"""
print("Setting up LoRA for GPT-OSS...")
# LoRA configuration as per tutorial
    lora_cfg = config.lora_config or {}
    default_target_parameters = [
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj",
    ]
    lora_config = LoraConfig(
        r=lora_cfg.get("r", 8),
        lora_alpha=lora_cfg.get("lora_alpha", 16),
        target_modules=lora_cfg.get("target_modules", "all-linear"),
        target_parameters=lora_cfg.get("target_parameters", default_target_parameters),
    )
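    # The target_parameters entries above use a PEFT feature (peft >= 0.17) that
    # attaches LoRA to individual nn.Parameter tensors: here the MoE expert
    # projections in transformer layers 7, 15 and 23, which the module-based
    # target_modules setting cannot reach on its own.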
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()
return peft_model


def load_dataset_from_config(config):
"""Load dataset based on configuration"""
dataset_name = getattr(config, 'dataset_name', 'HuggingFaceH4/Multilingual-Thinking')
dataset_split = getattr(config, 'dataset_split', 'train')
dataset_config = getattr(config, 'dataset_config', None)
print(f"Loading dataset: {dataset_name}")
print(f"Dataset split: {dataset_split}")
if dataset_config:
print(f"Dataset config: {dataset_config}")
# Load the dataset
if dataset_config:
dataset = load_dataset(dataset_name, dataset_config, split=dataset_split)
else:
dataset = load_dataset(dataset_name, split=dataset_split)
print(f"Original dataset size: {len(dataset)} examples")
# Apply filtering based on configuration
dataset = apply_dataset_filtering(dataset, config)
# Apply dataset processing based on format
dataset = process_dataset_format(dataset, config)
print(f"Final dataset size: {len(dataset)} examples")
return dataset


def apply_dataset_filtering(dataset, config):
"""Apply filtering based on configuration"""
# Filter bad entries if specified
if getattr(config, 'filter_bad_entries', False):
bad_entry_field = getattr(config, 'bad_entry_field', 'bad_entry')
bad_prompt_field = getattr(config, 'bad_prompt_field', 'bad_prompt_detected')
bad_response_field = getattr(config, 'bad_response_field', 'bad_response_detected')
original_size = len(dataset)
# Filter out bad entries
if bad_entry_field in dataset.column_names:
dataset = dataset.filter(lambda x: not x.get(bad_entry_field, False))
print(f"Filtered {original_size - len(dataset)} bad entries")
# Filter out bad prompts
if bad_prompt_field in dataset.column_names:
dataset = dataset.filter(lambda x: not x.get(bad_prompt_field, False))
print(f"Filtered bad prompts, remaining: {len(dataset)} examples")
# Filter out bad responses
if bad_response_field in dataset.column_names:
dataset = dataset.filter(lambda x: not x.get(bad_response_field, False))
print(f"Filtered bad responses, remaining: {len(dataset)} examples")
# Apply length filtering
min_length = getattr(config, 'min_length', 10)
max_length = getattr(config, 'max_length', None)
input_field = getattr(config, 'input_field', 'prompt')
target_field = getattr(config, 'target_field', 'accepted_completion')
if min_length > 0 or max_length:
def length_filter(example):
input_len = len(example.get(input_field, ''))
target_len = len(example.get(target_field, ''))
total_len = input_len + target_len
if total_len < min_length:
return False
if max_length and total_len > max_length:
return False
return True
original_size = len(dataset)
dataset = dataset.filter(length_filter)
print(f"Length filtering: {original_size} -> {len(dataset)} examples")
# Apply sampling if specified
max_samples = getattr(config, 'max_samples', None)
if max_samples and len(dataset) > max_samples:
dataset = dataset.shuffle(seed=42).select(range(max_samples))
print(f"Sampled {max_samples} examples from dataset")
return dataset


def format_gpt_oss_harmony(prompt, completion, add_eos_token=True):
"""
Format data for GPT-OSS Harmony format following the exact template structure.
Based on: https://huggingface.co/openai/gpt-oss-20b/raw/main/chat_template.jinja
"""
    # GPT-OSS Harmony format structure (exact template compliance):
    #   User message:               <|start|>user<|message|>content<|end|>
    #   Assistant message (stored): <|start|>assistant<|channel|>final<|message|>content<|end|>
    #   Assistant message (target): <|start|>assistant<|channel|>final<|message|>content<|return|>
    # <|return|> terminates the response the model is being trained to generate;
    # <|end|> terminates messages kept in conversation history.
harmony_text = f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>{completion}"
if add_eos_token:
# Use <|return|> for training as per template specification
# This indicates the end of generation in training
harmony_text += "<|return|>"
else:
# Use <|end|> for inference
harmony_text += "<|end|>"
return harmony_text
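
# Worked example of the formatter above (the strings are illustrative):
#   format_gpt_oss_harmony("What is 2+2?", "4")
#   -> "<|start|>user<|message|>What is 2+2?<|end|><|start|>assistant<|channel|>final<|message|>4<|return|>"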


def process_dataset_format(dataset, config):
"""Process dataset based on format configuration with exact GPT-OSS Harmony compliance"""
dataset_format = getattr(config, 'dataset_format', 'openhermes_fr')
input_field = getattr(config, 'input_field', 'prompt')
target_field = getattr(config, 'target_field', 'accepted_completion')
concatenate_fields = getattr(config, 'concatenate_fields', True)
field_separator = getattr(config, 'field_separator', '\n\n### Response:\n')
add_eos_token = getattr(config, 'add_eos_token', True)
use_harmony_format = getattr(config, 'use_harmony_format', True)
print(f"Processing dataset format: {dataset_format}")
print(f"Input field: {input_field}, Target field: {target_field}")
print(f"GPT-OSS Harmony Format: {'Enabled' if use_harmony_format else 'Disabled'}")
if dataset_format == "openhermes_fr":
# Process OpenHermes-FR format: prompt + accepted_completion
def format_openhermes_fr(example):
prompt = example.get(input_field, '')
completion = example.get(target_field, '')
if concatenate_fields:
if use_harmony_format:
# Use exact GPT-OSS Harmony format from template
text = format_gpt_oss_harmony(prompt, completion, add_eos_token)
else:
# Fallback to standard format with separator
text = prompt + field_separator + completion
if add_eos_token:
text += "</s>"
return {"text": text}
else:
# Keep separate for more advanced training setups
return {
"input": prompt,
"output": completion
}
dataset = dataset.map(format_openhermes_fr, remove_columns=dataset.column_names)
elif dataset_format == "messages":
# Process messages format (like HuggingFaceH4/Multilingual-Thinking)
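        # Each example is expected to carry a list of chat turns under the
        # configured input field, e.g. (illustrative):
        #   {"messages": [{"role": "user", "content": "..."},
        #                 {"role": "assistant", "content": "..."}]}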
def format_messages(example):
messages = example.get(input_field, [])
if use_harmony_format and len(messages) >= 2:
# Extract user and assistant messages for harmony format
user_message = ""
assistant_message = ""
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
if role == "user":
user_message = content
elif role == "assistant":
assistant_message = content
if user_message and assistant_message:
# Use GPT-OSS Harmony format
text = format_gpt_oss_harmony(user_message, assistant_message, add_eos_token)
else:
# Fallback to simple concatenation
text = ""
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
text += f"{role}: {content}\n"
if add_eos_token:
text += "</s>"
else:
# Standard format - convert messages to simple text
text = ""
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
text += f"{role}: {content}\n"
if add_eos_token:
text += "</s>"
return {"text": text}
dataset = dataset.map(format_messages, remove_columns=dataset.column_names)
elif dataset_format == "text":
# Process plain text format
text_field = input_field
def format_text(example):
text = example.get(text_field, '')
if add_eos_token:
text += "</s>"
return {"text": text}
dataset = dataset.map(format_text, remove_columns=dataset.column_names)
elif dataset_format == "custom":
# Custom format - user handles this in their config
print("Using custom dataset format - no automatic processing")
return dataset


def setup_trackio_tracking(config):
"""Setup Trackio tracking if enabled"""
if not config.enable_tracking or not config.trackio_url:
print("Trackio tracking disabled or URL not provided")
return None
print(f"Setting up Trackio tracking: {config.trackio_url}")
    # Import the TrackioAPIClient bundled with this repo (os and sys are already
    # imported at module level)
    sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'trackio_tonic'))
    from trackio_api_client import TrackioAPIClient
# Initialize Trackio client using the correct API
trackio_client = TrackioAPIClient(
space_id=config.trackio_url,
hf_token=config.trackio_token
)
return trackio_client


def create_sft_config(config, output_dir):
"""Create enhanced SFTConfig for GPT-OSS training"""
print("Creating enhanced SFT configuration...")
# Extract training parameters from config with enhanced defaults
num_train_epochs = getattr(config, 'num_train_epochs', 1.0)
max_steps = getattr(config, 'max_steps', None)
warmup_ratio = getattr(config, 'warmup_ratio', 0.03)
warmup_steps = getattr(config, 'warmup_steps', None)
# Learning rate configuration
learning_rate = config.learning_rate
lr_scheduler_type = getattr(config, 'scheduler', 'cosine_with_min_lr')
lr_scheduler_kwargs = getattr(config, 'lr_scheduler_kwargs', {"min_lr_rate": 0.1})
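    # With the cosine_with_min_lr scheduler, min_lr_rate=0.1 floors the schedule
    # at 10% of the peak learning rate instead of annealing to zero.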
# Batch configuration
per_device_train_batch_size = config.batch_size
per_device_eval_batch_size = getattr(config, 'eval_batch_size', config.batch_size)
gradient_accumulation_steps = config.gradient_accumulation_steps
# Evaluation and logging
    eval_strategy = getattr(config, 'eval_strategy', 'no')  # no eval_dataset is passed to SFTTrainer below, so default evaluation off
eval_steps = getattr(config, 'eval_steps', 100)
logging_steps = getattr(config, 'logging_steps', 10)
# Saving configuration
save_strategy = getattr(config, 'save_strategy', 'steps')
save_steps = getattr(config, 'save_steps', 500)
save_total_limit = getattr(config, 'save_total_limit', 3)
# Mixed precision
fp16 = getattr(config, 'fp16', False)
bf16 = getattr(config, 'bf16', True)
# Regularization
weight_decay = getattr(config, 'weight_decay', 0.01)
max_grad_norm = getattr(config, 'max_grad_norm', 1.0)
# HuggingFace Hub integration
push_to_hub = getattr(config, 'push_to_hub', False)
print(f" • Epochs: {num_train_epochs}")
print(f" • Learning rate: {learning_rate}")
print(f" • Batch size: {per_device_train_batch_size}")
print(f" • Gradient accumulation: {gradient_accumulation_steps}")
print(f" • Effective batch size: {per_device_train_batch_size * gradient_accumulation_steps}")
sft_config = SFTConfig(
# Training duration
num_train_epochs=num_train_epochs,
        max_steps=max_steps if max_steps is not None else -1,  # -1 means "train by epochs"; None would crash the Trainer
# Learning rate
learning_rate=learning_rate,
lr_scheduler_type=lr_scheduler_type,
lr_scheduler_kwargs=lr_scheduler_kwargs,
warmup_ratio=warmup_ratio,
        warmup_steps=warmup_steps if warmup_steps is not None else 0,  # 0 defers to warmup_ratio; None would crash
# Batch configuration
per_device_train_batch_size=per_device_train_batch_size,
per_device_eval_batch_size=per_device_eval_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
# Model configuration
        max_seq_length=config.max_seq_length,  # note: renamed to `max_length` in newer TRL releases
gradient_checkpointing=getattr(config, 'use_gradient_checkpointing', True),
# Mixed precision
fp16=fp16,
bf16=bf16,
# Regularization
weight_decay=weight_decay,
max_grad_norm=max_grad_norm,
# Evaluation
        eval_strategy=eval_strategy,  # `evaluation_strategy` was removed from transformers in 4.46
eval_steps=eval_steps,
# Logging
logging_steps=logging_steps,
# Saving
save_strategy=save_strategy,
save_steps=save_steps,
save_total_limit=save_total_limit,
# Output
output_dir=output_dir,
# Data loading
dataloader_num_workers=getattr(config, 'dataloader_num_workers', 4),
dataloader_pin_memory=getattr(config, 'dataloader_pin_memory', True),
# Performance
group_by_length=getattr(config, 'group_by_length', True),
remove_unused_columns=getattr(config, 'remove_unused_columns', True),
# HuggingFace Hub
push_to_hub=push_to_hub,
        # Monitoring (the string "none" disables reporting; report_to=None would
        # enable every installed integration)
        report_to="trackio" if getattr(config, 'enable_tracking', False) else "none",
)
return sft_config


def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer_type="sft"):
"""Main training function for GPT-OSS"""
print("=== GPT-OSS Training Pipeline ===")
print(f"Config: {config_path}")
print(f"Experiment: {experiment_name}")
print(f"Output: {output_dir}")
print(f"Trackio: {trackio_url}")
print(f"Trainer: {trainer_type}")
# Load configuration
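    # The config file is a plain Python module that must either expose a
    # module-level `config` object or define a GPT-OSS config instance, e.g.
    # (illustrative; `GPTOSSConfig` is a hypothetical name):
    #   config = GPTOSSConfig(model_name="openai/gpt-oss-20b", batch_size=4, ...)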
if os.path.exists(config_path):
import importlib.util
spec = importlib.util.spec_from_file_location("config_module", config_path)
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
if hasattr(config_module, 'config'):
config = config_module.config
else:
# Try to find a config class
for attr_name in dir(config_module):
attr = getattr(config_module, attr_name)
if hasattr(attr, 'model_name') and ('gpt_oss' in attr.model_name.lower() or 'GPTOSS' in attr_name):
config = attr
break
else:
raise ValueError(f"No GPT-OSS configuration found in {config_path}")
else:
raise FileNotFoundError(f"Configuration file not found: {config_path}")
# Update config with runtime parameters
config.experiment_name = experiment_name
config.trackio_url = trackio_url
config.trainer_type = trainer_type
# Load model and tokenizer
model, tokenizer = load_gpt_oss_model_and_tokenizer(config)
# Setup LoRA
peft_model = setup_lora_for_gpt_oss(model, config)
# Load dataset
dataset = load_dataset_from_config(config)
# Setup Trackio tracking
trackio_client = setup_trackio_tracking(config)
# Create SFT configuration
sft_config = create_sft_config(config, output_dir)
    # Create trainer (note: only SFT is implemented here; --trainer-type dpo is
    # accepted by the CLI but still falls through to SFTTrainer)
    print("Creating SFT trainer...")
trainer = SFTTrainer(
model=peft_model,
args=sft_config,
train_dataset=dataset,
processing_class=tokenizer,
)
# Start training
print("Starting GPT-OSS training...")
trainer.train()
# Save model
print("Saving trained model...")
trainer.save_model(output_dir)
# Push to hub if enabled
if sft_config.push_to_hub:
print("Pushing model to Hugging Face Hub...")
        trainer.push_to_hub(dataset_name=getattr(config, 'dataset_name', None))  # credit the dataset actually used, not a hardcoded one
print("GPT-OSS training completed successfully!")
return trainer


def main():
parser = argparse.ArgumentParser(description="GPT-OSS Training Script")
parser.add_argument("--config", required=True, help="Path to configuration file")
parser.add_argument("--experiment-name", required=True, help="Experiment name")
parser.add_argument("--output-dir", required=True, help="Output directory for checkpoints")
parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
parser.add_argument("--trainer-type", default="sft", choices=["sft", "dpo"], help="Trainer type")
args = parser.parse_args()
# Validate arguments
if not os.path.exists(args.config):
print(f"Error: Configuration file not found: {args.config}")
sys.exit(1)
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
try:
train_gpt_oss(
config_path=args.config,
experiment_name=args.experiment_name,
output_dir=args.output_dir,
trackio_url=args.trackio_url,
trainer_type=args.trainer_type
)
except Exception as e:
print(f"Error during training: {e}")
sys.exit(1)


if __name__ == "__main__":
main()