|
|
"""Merge configuration YAML generator with presets and validation.""" |
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
from typing import Optional |
|
|
import yaml |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Registry of supported mergekit merge methods.
#
# Schema per entry:
#   name            — human-readable label for UI/error messages
#   description     — one-line summary of the algorithm
#   min_models / max_models — allowed size of the models list (inclusive)
#   needs_base      — True if a base_model must be supplied
#   params          — per-model parameters the method accepts
#   global_params   — method-level parameters (emitted once, not per model)
#   supports_slices — whether layer-slice configs are allowed
#   requires_slices — (optional key) slices are mandatory, not just allowed
MERGE_METHODS = {
    # Drop-And-REscale + TIES: default workhorse for multi-model merges.
    "dare_ties": {
        "name": "DARE-TIES",
        "description": "Drop And REscale with TIES — trims low-magnitude parameters and resolves sign conflicts. Best for combining 2+ specialist models.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight", "density"],
        "global_params": ["int8_mask", "normalize"],
        "supports_slices": True,
    },
    # TIES without the random drop step.
    "ties": {
        "name": "TIES",
        "description": "Trim, Elect Sign, Merge — resolves parameter interference between models. Similar to DARE-TIES but without the drop step.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight", "density"],
        "global_params": ["int8_mask", "normalize"],
        "supports_slices": True,
    },
    # Exactly two models; blend controlled by the single global "t".
    "slerp": {
        "name": "SLERP",
        "description": "Spherical Linear Interpolation — smoothly blends two models along a curved path in weight space. Best for two-model merges.",
        "min_models": 2,
        "max_models": 2,
        "needs_base": False,
        "params": [],
        "global_params": ["t"],
        "supports_slices": True,
    },
    # Plain weighted average; no base model required.
    "linear": {
        "name": "Linear",
        "description": "Simple weighted average of model parameters. Fast and predictable baseline.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": False,
        "params": ["weight"],
        "global_params": ["normalize"],
        "supports_slices": True,
    },
    # Single model allowed (min_models=1): add/remove one task vector.
    "task_arithmetic": {
        "name": "Task Arithmetic",
        "description": "Add or subtract task vectors from a base model. Use negative weights to remove capabilities.",
        "min_models": 1,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight"],
        "global_params": [],
        "supports_slices": False,
    },
    # Layer stacking; the only method where slices are mandatory.
    "passthrough": {
        "name": "Passthrough (Frankenmerge)",
        "description": "Stack layers from different models. Can create larger models from smaller ones. Supports different layer counts.",
        "min_models": 1,
        "max_models": 10,
        "needs_base": False,
        "params": [],
        "global_params": [],
        "supports_slices": True,
        "requires_slices": True,
    },
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class MergePreset:
    """A named weighting recipe: turns a model list into weights/densities."""

    name: str               # human-readable preset label
    description: str        # one-line explanation for UI display
    method: str             # key into MERGE_METHODS
    weight_strategy: str    # one of: equal, first_dominant, last_dominant, auto_detect

    def apply(self, model_ids: list[str]) -> tuple[list[float], list[float]]:
        """Generate weights and densities for given models."""
        count = len(model_ids)
        if not model_ids:
            return [], []

        if self.weight_strategy == "first_dominant":
            # 0.6 to the lead model, the remaining 0.4 split evenly.
            if count == 1:
                return [1.0], [0.7]
            minor = round(0.4 / (count - 1), 3)
            return [0.6] + [minor] * (count - 1), [0.7] + [0.5] * (count - 1)

        if self.weight_strategy == "last_dominant":
            # Mirror image: 0.6 to the final model.
            if count == 1:
                return [1.0], [0.7]
            minor = round(0.4 / (count - 1), 3)
            return [minor] * (count - 1) + [0.6], [0.5] * (count - 1) + [0.7]

        if self.weight_strategy == "auto_detect":
            return _auto_detect_weights(model_ids)

        # "equal" and any unrecognized strategy fall back to a uniform split.
        even_share = round(1.0 / count, 3)
        return [even_share] * count, [0.6] * count
|
|
|
|
|
|
|
|
def _auto_detect_weights(model_ids: list[str]) -> tuple[list[float], list[float]]: |
|
|
"""Auto-detect optimal weights based on model names/tags.""" |
|
|
n = len(model_ids) |
|
|
weights = [] |
|
|
densities = [] |
|
|
|
|
|
for mid in model_ids: |
|
|
name = mid.lower() |
|
|
if "code" in name or "coder" in name: |
|
|
weights.append(0.5) |
|
|
densities.append(0.7) |
|
|
elif "math" in name: |
|
|
weights.append(0.4) |
|
|
densities.append(0.6) |
|
|
elif "instruct" in name and "code" not in name: |
|
|
weights.append(0.3) |
|
|
densities.append(0.5) |
|
|
else: |
|
|
weights.append(0.3) |
|
|
densities.append(0.5) |
|
|
|
|
|
|
|
|
total = sum(weights) |
|
|
if total > 0: |
|
|
weights = [round(w / total, 3) for w in weights] |
|
|
|
|
|
return weights, densities |
|
|
|
|
|
|
|
|
# Ready-made presets: friendly key -> (label, description, method, strategy).
# Keys are the names accepted by generate_from_preset().
PRESETS = {
    "equal": MergePreset("Equal", "Equal weights for all models", "dare_ties", "equal"),
    "first_dominant": MergePreset("First Model Dominant", "Prioritize the first model", "dare_ties", "first_dominant"),
    "last_dominant": MergePreset("Last Model Dominant", "Prioritize the last model", "dare_ties", "last_dominant"),
    "coding_focus": MergePreset("Coding Focus", "Higher weight for code-related models", "dare_ties", "auto_detect"),
    "balanced_slerp": MergePreset("Balanced SLERP", "50/50 interpolation between two models", "slerp", "equal"),
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class MergeConfig:
    """Complete merge configuration."""

    # Core selection
    method: str = "dare_ties"                               # key into MERGE_METHODS
    models: list[str] = field(default_factory=list)         # models to merge
    base_model: str = ""                                    # required by some methods
    weights: list[float] = field(default_factory=list)      # per-model, parallel to models
    densities: list[float] = field(default_factory=list)    # per-model, parallel to models
    tokenizer_source: str = ""                              # which model's tokenizer to keep
    dtype: str = "bfloat16"                                 # output parameter dtype

    # Method-specific knobs
    slerp_t: float = 0.5        # SLERP interpolation factor, valid range [0, 1]
    int8_mask: bool = True      # TIES-family global flag
    normalize: bool = True      # TIES/linear global flag

    # Layer-slice definitions (passthrough / frankenmerge)
    slices: list[dict] = field(default_factory=list)

    output_name: str = ""       # optional name for the merged model

    def validate(self) -> list[str]:
        """Validate the configuration. Returns list of error messages."""
        info = MERGE_METHODS.get(self.method)
        if info is None:
            # Nothing else can be checked without method metadata.
            return [f"Unknown merge method: {self.method}"]

        problems: list[str] = []
        count = len(self.models)

        # Model-count bounds for the chosen method.
        if count < info["min_models"]:
            problems.append(f"{info['name']} requires at least {info['min_models']} models")
        if count > info["max_models"]:
            problems.append(f"{info['name']} supports at most {info['max_models']} models")

        if info["needs_base"] and not self.base_model:
            problems.append(f"{info['name']} requires a base_model")

        # Per-model parameter lists, when supplied, must line up with models.
        if "weight" in info["params"] and self.weights:
            if len(self.weights) != count:
                problems.append(f"Expected {count} weights, got {len(self.weights)}")
            # Range allows negatives for task-arithmetic-style subtraction.
            if any(w < -1 or w > 2 for w in self.weights):
                problems.append("Weights should be between -1 and 2")

        if "density" in info["params"] and self.densities:
            if len(self.densities) != count:
                problems.append(f"Expected {count} densities, got {len(self.densities)}")
            if any(d < 0 or d > 1 for d in self.densities):
                problems.append("Densities must be between 0 and 1")

        if self.method == "slerp" and (self.slerp_t < 0 or self.slerp_t > 1):
            problems.append("SLERP t parameter must be between 0 and 1")

        if info.get("requires_slices") and not self.slices:
            problems.append(f"{info['name']} requires slice definitions")

        return problems
|
|
|
|
|
|
|
|
def generate_yaml(config: MergeConfig) -> str:
    """Generate mergekit-compatible YAML configuration.

    Args:
        config: MergeConfig with all parameters

    Returns:
        YAML string ready for mergekit. If validation fails, a block of
        YAML comments listing the errors is returned instead, so the
        output is always safe to write to a .yaml file.
    """
    errors = config.validate()
    if errors:
        # Plain string literal — the previous f-string had no placeholders.
        return "# VALIDATION ERRORS:\n" + "\n".join(f"# - {e}" for e in errors)

    method_info = MERGE_METHODS[config.method]
    doc = {}

    # Passthrough merges are slice-only: no base model, no per-model params.
    if config.method == "passthrough":
        doc["slices"] = config.slices or _default_slices(config)
        doc["merge_method"] = config.method
        doc["dtype"] = config.dtype
        # NOTE(review): tokenizer_source is omitted for passthrough configs —
        # confirm whether mergekit passthrough merges should include it too.
        return yaml.dump(doc, default_flow_style=False, sort_keys=False)

    doc["merge_method"] = config.method

    if method_info["needs_base"]:
        doc["base_model"] = config.base_model

    if config.method == "slerp":
        # SLERP takes a single global interpolation factor, not per-model params.
        doc["models"] = [{"model": m} for m in config.models]
        doc["parameters"] = {"t": config.slerp_t}
    else:
        # Per-model entries with optional weight/density parameters.
        models_list = []
        for i, model_id in enumerate(config.models):
            entry = {"model": model_id}
            params = {}
            if "weight" in method_info["params"] and config.weights:
                params["weight"] = config.weights[i]
            if "density" in method_info["params"] and config.densities:
                params["density"] = config.densities[i]
            if params:
                entry["parameters"] = params
            models_list.append(entry)
        doc["models"] = models_list

    # Method-level (global) parameters; only set keys the method declares.
    global_params = {}
    if "int8_mask" in method_info.get("global_params", []):
        global_params["int8_mask"] = config.int8_mask
    if "normalize" in method_info.get("global_params", []):
        global_params["normalize"] = config.normalize

    if global_params:
        doc["parameters"] = global_params

    doc["dtype"] = config.dtype

    if config.tokenizer_source:
        doc["tokenizer_source"] = config.tokenizer_source

    # sort_keys=False preserves the insertion order mergekit configs use.
    return yaml.dump(doc, default_flow_style=False, sort_keys=False)
|
|
|
|
|
|
|
|
def _default_slices(config: MergeConfig) -> list[dict]: |
|
|
"""Generate default slice config for passthrough merges.""" |
|
|
slices = [] |
|
|
for model_id in config.models: |
|
|
slices.append({ |
|
|
"sources": [{"model": model_id, "layer_range": [0, 32]}] |
|
|
}) |
|
|
return slices |
|
|
|
|
|
|
|
|
def generate_from_preset(
    preset_name: str,
    model_ids: list[str],
    base_model: str = "",
    tokenizer_source: str = "",
    dtype: str = "bfloat16",
) -> str:
    """Quick config generation from a preset name.

    Args:
        preset_name: Key from PRESETS dict
        model_ids: List of model IDs to merge
        base_model: Base model for methods that need one
        tokenizer_source: Which model's tokenizer to use
        dtype: Data type for merge

    Returns:
        YAML string (or a comment block describing the problem)
    """
    preset = PRESETS.get(preset_name)
    if preset is None:
        # Unknown preset: return an informative YAML comment instead of raising.
        return f"# Unknown preset: {preset_name}\n# Available: {', '.join(PRESETS.keys())}"

    weights, densities = preset.apply(model_ids)

    # When no base/tokenizer is given, fall back to the first model (if any).
    first_model = model_ids[0] if model_ids else ""
    config = MergeConfig(
        method=preset.method,
        models=model_ids,
        base_model=base_model or first_model,
        weights=weights,
        densities=densities,
        tokenizer_source=tokenizer_source or base_model or first_model,
        dtype=dtype,
    )
    return generate_yaml(config)
|
|
|
|
|
|
|
|
def get_method_info(method: str) -> dict:
    """Get human-readable info about a merge method."""
    # Unknown methods get a placeholder entry rather than raising KeyError.
    fallback = {"name": "Unknown", "description": "Unknown method"}
    return MERGE_METHODS.get(method, fallback)
|
|
|