# forgekit/config_generator.py
"""Merge configuration YAML generator with presets and validation."""
from dataclasses import dataclass, field
from typing import Optional
import yaml
# ===== MERGE METHOD DEFINITIONS =====
# Registry of supported merge methods, keyed by mergekit method id.
# Schema per entry:
#   name            -- human-readable display name
#   description     -- one-line user-facing summary
#   min_models / max_models -- allowed count of input models
#   needs_base      -- True if a base_model must be supplied
#   params          -- per-model parameters accepted ("weight", "density")
#   global_params   -- method-wide parameters ("int8_mask", "normalize", "t")
#   supports_slices -- layer-slice configuration is allowed
#   requires_slices -- (optional) slice definitions are mandatory
MERGE_METHODS = {
    "dare_ties": {
        "name": "DARE-TIES",
        "description": "Drop And REscale with TIES — trims low-magnitude parameters and resolves sign conflicts. Best for combining 2+ specialist models.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight", "density"],
        "global_params": ["int8_mask", "normalize"],
        "supports_slices": True,
    },
    "ties": {
        "name": "TIES",
        "description": "Trim, Elect Sign, Merge — resolves parameter interference between models. Similar to DARE-TIES but without the drop step.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight", "density"],
        "global_params": ["int8_mask", "normalize"],
        "supports_slices": True,
    },
    "slerp": {
        "name": "SLERP",
        "description": "Spherical Linear Interpolation — smoothly blends two models along a curved path in weight space. Best for two-model merges.",
        "min_models": 2,
        "max_models": 2,
        "needs_base": False,
        "params": [],
        "global_params": ["t"],
        "supports_slices": True,
    },
    "linear": {
        "name": "Linear",
        "description": "Simple weighted average of model parameters. Fast and predictable baseline.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": False,
        "params": ["weight"],
        "global_params": ["normalize"],
        "supports_slices": True,
    },
    "task_arithmetic": {
        "name": "Task Arithmetic",
        "description": "Add or subtract task vectors from a base model. Use negative weights to remove capabilities.",
        "min_models": 1,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight"],
        "global_params": [],
        "supports_slices": False,
    },
    "passthrough": {
        "name": "Passthrough (Frankenmerge)",
        "description": "Stack layers from different models. Can create larger models from smaller ones. Supports different layer counts.",
        "min_models": 1,
        "max_models": 10,
        "needs_base": False,
        "params": [],
        "global_params": [],
        "supports_slices": True,
        # Passthrough is slices-only: without slice definitions there is
        # nothing to merge (validated in MergeConfig.validate).
        "requires_slices": True,
    },
}
# ===== PRESETS =====
@dataclass
class MergePreset:
    """A named weighting recipe that maps a model list to weights/densities."""
    name: str
    description: str
    method: str
    weight_strategy: str  # "equal", "first_dominant", "last_dominant", "auto_detect"

    def apply(self, model_ids: list[str]) -> tuple[list[float], list[float]]:
        """Generate weights and densities for given models."""
        count = len(model_ids)
        if count == 0:
            return [], []
        strategy = self.weight_strategy
        if strategy == "auto_detect":
            # Delegate to the name-based heuristic.
            return _auto_detect_weights(model_ids)
        if strategy == "first_dominant":
            if count > 1:
                share = round(0.4 / (count - 1), 3)
                weights = [0.6] + [share] * (count - 1)
            else:
                weights = [1.0]
            return weights, [0.7] + [0.5] * (count - 1)
        if strategy == "last_dominant":
            if count > 1:
                share = round(0.4 / (count - 1), 3)
                weights = [share] * (count - 1) + [0.6]
            else:
                weights = [1.0]
            return weights, [0.5] * (count - 1) + [0.7]
        # "equal" — and the fallback for any unrecognized strategy.
        uniform = round(1.0 / count, 3)
        return [uniform] * count, [0.6] * count
def _auto_detect_weights(model_ids: list[str]) -> tuple[list[float], list[float]]:
"""Auto-detect optimal weights based on model names/tags."""
n = len(model_ids)
weights = []
densities = []
for mid in model_ids:
name = mid.lower()
if "code" in name or "coder" in name:
weights.append(0.5)
densities.append(0.7)
elif "math" in name:
weights.append(0.4)
densities.append(0.6)
elif "instruct" in name and "code" not in name:
weights.append(0.3)
densities.append(0.5)
else:
weights.append(0.3)
densities.append(0.5)
# Normalize weights to sum to 1
total = sum(weights)
if total > 0:
weights = [round(w / total, 3) for w in weights]
return weights, densities
# Ready-made weighting recipes, keyed by preset id.
# Each preset pairs a merge method with a weight strategy from MergePreset.
PRESETS = {
    "equal": MergePreset("Equal", "Equal weights for all models", "dare_ties", "equal"),
    "first_dominant": MergePreset("First Model Dominant", "Prioritize the first model", "dare_ties", "first_dominant"),
    "last_dominant": MergePreset("Last Model Dominant", "Prioritize the last model", "dare_ties", "last_dominant"),
    "coding_focus": MergePreset("Coding Focus", "Higher weight for code-related models", "dare_ties", "auto_detect"),
    "balanced_slerp": MergePreset("Balanced SLERP", "50/50 interpolation between two models", "slerp", "equal"),
}
# ===== CONFIG GENERATION =====
@dataclass
class MergeConfig:
    """Complete merge configuration."""
    method: str = "dare_ties"
    models: list[str] = field(default_factory=list)
    base_model: str = ""
    weights: list[float] = field(default_factory=list)
    densities: list[float] = field(default_factory=list)
    tokenizer_source: str = ""
    dtype: str = "bfloat16"
    # Method-specific params
    slerp_t: float = 0.5
    int8_mask: bool = True
    normalize: bool = True
    # Passthrough/slice params
    slices: list[dict] = field(default_factory=list)
    # Output
    output_name: str = ""

    def validate(self) -> list[str]:
        """Validate the configuration. Returns list of error messages."""
        info = MERGE_METHODS.get(self.method)
        if info is None:
            # Nothing else can be checked without a known method.
            return [f"Unknown merge method: {self.method}"]
        problems: list[str] = []
        count = len(self.models)
        label = info["name"]
        if count < info["min_models"]:
            problems.append(f"{label} requires at least {info['min_models']} models")
        if count > info["max_models"]:
            problems.append(f"{label} supports at most {info['max_models']} models")
        if info["needs_base"] and not self.base_model:
            problems.append(f"{label} requires a base_model")
        per_model = info["params"]
        # Length/range checks only apply when values were actually provided.
        if "weight" in per_model and self.weights:
            if len(self.weights) != count:
                problems.append(f"Expected {count} weights, got {len(self.weights)}")
            if any(not (-1 <= w <= 2) for w in self.weights):
                problems.append("Weights should be between -1 and 2")
        if "density" in per_model and self.densities:
            if len(self.densities) != count:
                problems.append(f"Expected {count} densities, got {len(self.densities)}")
            if any(not (0 <= d <= 1) for d in self.densities):
                problems.append("Densities must be between 0 and 1")
        if self.method == "slerp" and not (0 <= self.slerp_t <= 1):
            problems.append("SLERP t parameter must be between 0 and 1")
        if info.get("requires_slices") and not self.slices:
            problems.append(f"{label} requires slice definitions")
        return problems
def generate_yaml(config: MergeConfig) -> str:
    """Generate mergekit-compatible YAML configuration.

    Args:
        config: MergeConfig with all parameters

    Returns:
        YAML string ready for mergekit
    """
    problems = config.validate()
    if problems:
        # Surface validation failures as YAML comments instead of a config.
        body = "\n".join(f"# - {p}" for p in problems)
        return "# VALIDATION ERRORS:\n" + body
    info = MERGE_METHODS[config.method]
    doc: dict = {}
    if config.method == "passthrough":
        # Passthrough (frankenmerge) uses the slices format exclusively.
        doc["slices"] = config.slices or _default_slices(config)
        doc["merge_method"] = config.method
        doc["dtype"] = config.dtype
        return yaml.dump(doc, default_flow_style=False, sort_keys=False)
    # Standard methods
    doc["merge_method"] = config.method
    if info["needs_base"]:
        doc["base_model"] = config.base_model
    if config.method == "slerp":
        # SLERP takes no per-model parameters, only the global t.
        doc["models"] = [{"model": m} for m in config.models]
        doc["parameters"] = {"t": config.slerp_t}
    else:
        entries = []
        for idx, model_id in enumerate(config.models):
            entry: dict = {"model": model_id}
            per_model: dict = {}
            if config.weights and "weight" in info["params"]:
                per_model["weight"] = config.weights[idx]
            if config.densities and "density" in info["params"]:
                per_model["density"] = config.densities[idx]
            if per_model:
                entry["parameters"] = per_model
            entries.append(entry)
        doc["models"] = entries
    # Global (method-wide) parameters; empty for methods without them.
    globals_block = {
        key: getattr(config, key)
        for key in ("int8_mask", "normalize")
        if key in info.get("global_params", [])
    }
    if globals_block:
        doc["parameters"] = globals_block
    doc["dtype"] = config.dtype
    if config.tokenizer_source:
        doc["tokenizer_source"] = config.tokenizer_source
    return yaml.dump(doc, default_flow_style=False, sort_keys=False)
def _default_slices(config: MergeConfig) -> list[dict]:
"""Generate default slice config for passthrough merges."""
slices = []
for model_id in config.models:
slices.append({
"sources": [{"model": model_id, "layer_range": [0, 32]}]
})
return slices
def generate_from_preset(
    preset_name: str,
    model_ids: list[str],
    base_model: str = "",
    tokenizer_source: str = "",
    dtype: str = "bfloat16",
) -> str:
    """Quick config generation from a preset name.

    Args:
        preset_name: Key from PRESETS dict
        model_ids: List of model IDs to merge
        base_model: Base model for methods that need one
        tokenizer_source: Which model's tokenizer to use
        dtype: Data type for merge

    Returns:
        YAML string
    """
    preset = PRESETS.get(preset_name)
    if preset is None:
        return f"# Unknown preset: {preset_name}\n# Available: {', '.join(PRESETS.keys())}"
    weights, densities = preset.apply(model_ids)
    # Fall back to the first model when no explicit base/tokenizer is given.
    first_model = model_ids[0] if model_ids else ""
    merge_config = MergeConfig(
        method=preset.method,
        models=model_ids,
        base_model=base_model or first_model,
        weights=weights,
        densities=densities,
        tokenizer_source=tokenizer_source or base_model or first_model,
        dtype=dtype,
    )
    return generate_yaml(merge_config)
def get_method_info(method: str) -> dict:
    """Get human-readable info about a merge method."""
    fallback = {"name": "Unknown", "description": "Unknown method"}
    return MERGE_METHODS.get(method, fallback)