# CodeCraftLab / config.py
# Hosting-page residue captured with the file view:
#   "S-Dreamer's picture" / "Upload 3 files" / "178abc4 verified"
"""
Training configuration schemas — Pydantic v2.
All training jobs are validated against these models before execution.
No raw dicts escape into the pipeline; everything is typed and constrained.
"""
from __future__ import annotations

from enum import StrEnum
from typing import Annotated, Literal

from pydantic import BaseModel, Field, HttpUrl, model_validator
from pydantic import PositiveFloat, PositiveInt
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class EvalStrategy(StrEnum):
NO = "no"
STEPS = "steps"
EPOCH = "epoch"
class Precision(StrEnum):
FP32 = "fp32"
FP16 = "fp16"
BF16 = "bf16"
INT8 = "int8"
class OptimizerType(StrEnum):
ADAMW = "adamw_torch"
ADAMW_8BIT = "adamw_8bit"
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
SGD = "sgd"
class EvalMetric(StrEnum):
PASS_AT_1 = "pass_at_1"
PASS_AT_10 = "pass_at_10"
BLEU = "bleu"
EXECUTION_ACCURACY = "execution_accuracy"
EXACT_MATCH = "exact_match"
# ---------------------------------------------------------------------------
# Sub-configs
# ---------------------------------------------------------------------------
class LoRAConfig(BaseModel):
    """LoRA adapter configuration.

    To run a full fine-tune, set ``enabled`` to ``False`` or pass ``None``
    for ``TrainingJobConfig.lora``. Merely omitting that field does *not*
    disable LoRA: its default factory builds this model, and ``enabled``
    defaults to ``True``.

    Raises:
        ValueError: reported by pydantic as a ``ValidationError`` when
            ``alpha`` is smaller than ``r``.
    """

    enabled: bool = True
    # Bounded to 1..256 inclusive.
    r: Annotated[int, Field(ge=1, le=256)] = 16
    # Must be >= 1 and, per the validator below, >= ``r``.
    alpha: Annotated[int, Field(ge=1)] = 32
    # Half-open range [0.0, 1.0).
    dropout: Annotated[float, Field(ge=0.0, lt=1.0)] = 0.05
    # At least one module name is required.
    target_modules: list[str] = Field(
        default_factory=lambda: ["q_proj", "v_proj"],
        min_length=1,
    )
    # Constrained so a typo is rejected at validation time instead of
    # travelling into the pipeline as an arbitrary string. The allowed
    # values are the bias modes of PEFT's ``LoraConfig`` -- this assumes the
    # pipeline forwards the field to PEFT; confirm before extending.
    bias: Literal["none", "all", "lora_only"] = "none"

    @model_validator(mode="after")
    def alpha_geq_r(self) -> "LoRAConfig":
        """Reject configurations whose ``alpha`` is below ``r``."""
        if self.alpha < self.r:
            raise ValueError(f"lora.alpha ({self.alpha}) should be >= lora.r ({self.r})")
        return self
class TrainingHyperparams(BaseModel):
    """Hyperparameters for a training run.

    Numeric fields are range-checked through their ``Field`` bounds; any
    field the caller leaves out takes the default shown here.
    """

    num_epochs: int = Field(default=3, ge=1, le=100)
    batch_size: int = Field(default=8, ge=1, le=256)
    gradient_accumulation_steps: int = Field(default=4, ge=1, le=128)
    # Strictly between 0 and 1.
    learning_rate: float = Field(default=2e-5, gt=0.0, lt=1.0)
    weight_decay: float = Field(default=0.01, ge=0.0, lt=1.0)
    warmup_ratio: float = Field(default=0.1, ge=0.0, lt=1.0)
    max_seq_length: int = Field(default=1024, ge=64, le=32768)
    max_grad_norm: float = Field(default=1.0, gt=0.0)
    optimizer: OptimizerType = OptimizerType.ADAMW
    precision: Precision = Precision.BF16
    # Free-form name; this model does not restrict its value.
    lr_scheduler: str = "cosine"
    seed: int = 42
    dataloader_num_workers: int = Field(default=4, ge=0, le=32)

    @property
    def effective_batch_size(self) -> int:
        """Return ``batch_size`` times ``gradient_accumulation_steps``."""
        per_step = self.batch_size
        accumulation = self.gradient_accumulation_steps
        return per_step * accumulation
class EvaluationConfig(BaseModel):
    """Evaluation schedule, metric selection and best-model settings.

    Raises:
        ValueError: reported by pydantic as a ``ValidationError`` when
            ``strategy`` is ``STEPS`` and ``eval_steps`` is missing.
    """

    enabled: bool = True
    strategy: EvalStrategy = EvalStrategy.EPOCH
    # Required when strategy=STEPS (enforced by the validator below).
    eval_steps: PositiveInt | None = None
    metrics: list[EvalMetric] = Field(
        default_factory=lambda: [EvalMetric.PASS_AT_1, EvalMetric.BLEU],
    )
    num_samples_per_problem: int = Field(default=10, ge=1, le=200)
    # Bounded to 1..60 inclusive.
    timeout_seconds: int = Field(default=10, ge=1, le=60)
    load_best_model_at_end: bool = True
    metric_for_best_model: EvalMetric = EvalMetric.PASS_AT_1
    greater_is_better: bool = True

    @model_validator(mode="after")
    def eval_steps_required_for_steps_strategy(self) -> "EvaluationConfig":
        """Require ``eval_steps`` whenever the strategy is step-based."""
        interval_not_needed = self.strategy != EvalStrategy.STEPS
        if interval_not_needed or self.eval_steps is not None:
            return self
        raise ValueError("evaluation.eval_steps is required when strategy='steps'")
class CheckpointConfig(BaseModel):
    """Checkpoint schedule, retention limit and output location.

    Raises:
        ValueError: reported by pydantic as a ``ValidationError`` when
            ``save_strategy`` is ``STEPS`` and ``save_steps`` is missing.
    """

    # Reuses the EvalStrategy enum for the save schedule.
    save_strategy: EvalStrategy = EvalStrategy.EPOCH
    save_steps: PositiveInt | None = None
    # Bounded to 1..20 inclusive.
    save_total_limit: int = Field(default=3, ge=1, le=20)
    output_dir: str = "./checkpoints"
    resume_from_checkpoint: str | None = None

    @model_validator(mode="after")
    def save_steps_required_for_steps_strategy(self) -> "CheckpointConfig":
        """Require ``save_steps`` whenever the save strategy is step-based."""
        interval_not_needed = self.save_strategy != EvalStrategy.STEPS
        if interval_not_needed or self.save_steps is not None:
            return self
        raise ValueError("checkpoint.save_steps required when save_strategy='steps'")
class HubConfig(BaseModel):
    """Settings for optionally pushing results to a hub repository.

    Raises:
        ValueError: reported by pydantic as a ``ValidationError`` when
            ``push_to_hub`` is true and ``repo_id`` is ``None`` or empty.
    """

    push_to_hub: bool = False
    repo_id: str | None = None
    private: bool = True
    commit_message: str = "Training checkpoint"

    @model_validator(mode="after")
    def repo_id_required_if_pushing(self) -> "HubConfig":
        """Require a non-empty ``repo_id`` when pushing is switched on."""
        # Nothing to check when pushing is off or a repo id was supplied.
        if not self.push_to_hub or self.repo_id:
            return self
        raise ValueError("hub.repo_id is required when hub.push_to_hub=true")
class DatasetConfig(BaseModel):
    """Dataset selection and sampling options for a training job."""

    # Internal UUID or HF Hub dataset path.
    dataset_id: str
    # Train split; must lie strictly between 0 and 1.
    split_ratio: float = Field(default=0.9, gt=0.0, lt=1.0)
    # None = use all.
    max_samples: PositiveInt | None = None
    text_column: str = "content"
    shuffle: bool = True
    shuffle_seed: int = 42
# ---------------------------------------------------------------------------
# Root job config
# ---------------------------------------------------------------------------
class TrainingJobConfig(BaseModel):
    """
    Root schema describing one complete training job.

    Composes the dataset, hyperparameter, LoRA, evaluation, checkpoint and
    hub sub-configs defined in this module. Intended to be validated at job
    submission time, so that a job which passes validation reaches the
    pipeline with a coherent configuration.
    """

    # 1..128 characters drawn from word characters and hyphens.
    job_name: str = Field(min_length=1, max_length=128, pattern=r"^[\w\-]+$")
    base_model: Annotated[
        str,
        Field(
            description="HuggingFace model ID or local path",
            examples=[
                "Salesforce/codegen-350M-mono",
                "deepseek-ai/deepseek-coder-1.3b-base",
            ],
        ),
    ]
    dataset: DatasetConfig
    training: TrainingHyperparams = Field(default_factory=TrainingHyperparams)
    # Defaults to a fresh LoRAConfig when omitted; ``None`` is also accepted.
    lora: LoRAConfig | None = Field(default_factory=LoRAConfig)
    evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
    checkpoint: CheckpointConfig = Field(default_factory=CheckpointConfig)
    hub: HubConfig = Field(default_factory=HubConfig)
    # At most 20 tags.
    tags: list[str] = Field(default_factory=list, max_length=20)
    notes: str | None = None

    # Example payload published with the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "job_name": "codegen-finetune-v1",
                    "base_model": "Salesforce/codegen-350M-mono",
                    "dataset": {"dataset_id": "ds_abc123"},
                    "training": {
                        "num_epochs": 3,
                        "batch_size": 8,
                        "learning_rate": 2e-5,
                    },
                    "hub": {
                        "push_to_hub": True,
                        "repo_id": "your-org/codegen-finetune-v1",
                    },
                },
            ],
        },
    }
# ---------------------------------------------------------------------------
# Inference config (served separately but validated here for consistency)
# ---------------------------------------------------------------------------
class InferenceConfig(BaseModel):
    """Generation parameters for an inference request.

    Each numeric field is range-checked through its ``Field`` bounds; the
    defaults below apply to any field the caller leaves out.
    """

    # ``model_id`` starts with ``model_``, a prefix Pydantic v2 reserves for
    # its own API; releases that protect the whole prefix by default emit a
    # UserWarning when this class is created. Clearing the protected
    # namespaces keeps the public field name and silences that warning.
    model_config = {"protected_namespaces": ()}

    model_id: str
    max_new_tokens: Annotated[int, Field(ge=1, le=4096)] = 256
    # Closed range [0.0, 2.0].
    temperature: Annotated[float, Field(ge=0.0, le=2.0)] = 0.2
    # Closed range [0.0, 1.0].
    top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.95
    top_k: Annotated[int, Field(ge=0, le=1000)] = 50
    do_sample: bool = True
    num_return_sequences: Annotated[int, Field(ge=1, le=200)] = 1
    stop_sequences: list[str] = Field(default_factory=list)
    precision: Precision = Precision.BF16