"""Merge configuration YAML generator with presets and validation."""

from dataclasses import dataclass, field
from typing import Optional

import yaml

# ===== MERGE METHOD DEFINITIONS =====

# Per-method metadata, keyed by the method identifier written to the YAML.
#   min_models / max_models : bounds enforced by MergeConfig.validate()
#   needs_base              : validate() demands a base_model; generate_yaml() emits it
#   params                  : per-model parameters ("weight", "density")
#   global_params           : top-level parameters for the method
#   supports_slices         : informational; not read anywhere in this module
#   requires_slices         : validate() demands explicit slice definitions
MERGE_METHODS = {
    "dare_ties": {
        "name": "DARE-TIES",
        "description": "Drop And REscale with TIES — trims low-magnitude parameters and resolves sign conflicts. Best for combining 2+ specialist models.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight", "density"],
        "global_params": ["int8_mask", "normalize"],
        "supports_slices": True,
    },
    "ties": {
        "name": "TIES",
        "description": "Trim, Elect Sign, Merge — resolves parameter interference between models. Similar to DARE-TIES but without the drop step.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight", "density"],
        "global_params": ["int8_mask", "normalize"],
        "supports_slices": True,
    },
    "slerp": {
        "name": "SLERP",
        "description": "Spherical Linear Interpolation — smoothly blends two models along a curved path in weight space. Best for two-model merges.",
        "min_models": 2,
        "max_models": 2,
        "needs_base": False,
        "params": [],
        "global_params": ["t"],
        "supports_slices": True,
    },
    "linear": {
        "name": "Linear",
        "description": "Simple weighted average of model parameters. Fast and predictable baseline.",
        "min_models": 2,
        "max_models": 10,
        "needs_base": False,
        "params": ["weight"],
        "global_params": ["normalize"],
        "supports_slices": True,
    },
    "task_arithmetic": {
        "name": "Task Arithmetic",
        "description": "Add or subtract task vectors from a base model. Use negative weights to remove capabilities.",
        "min_models": 1,
        "max_models": 10,
        "needs_base": True,
        "params": ["weight"],
        "global_params": [],
        "supports_slices": False,
    },
    "passthrough": {
        "name": "Passthrough (Frankenmerge)",
        "description": "Stack layers from different models. Can create larger models from smaller ones. Supports different layer counts.",
        "min_models": 1,
        "max_models": 10,
        "needs_base": False,
        "params": [],
        "global_params": [],
        "supports_slices": True,
        "requires_slices": True,
    },
}


# ===== PRESETS =====

@dataclass
class MergePreset:
    """Named recipe that picks a merge method and a weighting strategy."""

    name: str
    description: str
    method: str
    weight_strategy: str  # "equal", "first_dominant", "last_dominant", "auto_detect"

    def apply(self, model_ids: list[str]) -> tuple[list[float], list[float]]:
        """Generate weights and densities for the given models.

        Args:
            model_ids: Model identifiers, in merge order.

        Returns:
            ``(weights, densities)``, one entry per model. Both lists are
            empty when ``model_ids`` is empty. An unrecognised
            ``weight_strategy`` falls back to equal weighting.
        """
        n = len(model_ids)
        if n == 0:
            return [], []

        strategy = self.weight_strategy
        if strategy == "auto_detect":
            return _auto_detect_weights(model_ids)

        if strategy in ("first_dominant", "last_dominant"):
            if n == 1:
                # A lone model takes the whole weight and the dominant density.
                return [1.0], [0.7]
            # The dominant model gets 0.6; the rest split the remaining 0.4.
            minor_weights = [round(0.4 / (n - 1), 3)] * (n - 1)
            minor_densities = [0.5] * (n - 1)
            if strategy == "first_dominant":
                return [0.6] + minor_weights, [0.7] + minor_densities
            return minor_weights + [0.6], minor_densities + [0.7]

        # "equal" and any unknown strategy.
        return [round(1.0 / n, 3)] * n, [0.6] * n


def _auto_detect_weights(model_ids: list[str]) -> tuple[list[float], list[float]]:
    """Derive weights and densities from keywords in the model names.

    Names containing "code" get raw weight 0.5 / density 0.7, names containing
    "math" get 0.4 / 0.6, everything else gets 0.3 / 0.5. Matching is
    case-insensitive and "code" takes precedence over "math".

    Args:
        model_ids: Model identifiers to inspect.

    Returns:
        ``(weights, densities)`` where weights are normalised to sum to
        roughly 1 (each rounded to 3 decimals) and densities are unscaled.
    """
    weights = []
    densities = []
    for mid in model_ids:
        name = mid.lower()
        if "code" in name:  # also covers "coder"
            weight, density = 0.5, 0.7
        elif "math" in name:
            weight, density = 0.4, 0.6
        else:
            weight, density = 0.3, 0.5
        weights.append(weight)
        densities.append(density)

    # Normalize weights to sum to 1; skipped for empty input (total == 0).
    total = sum(weights)
    if total > 0:
        weights = [round(w / total, 3) for w in weights]
    return weights, densities


PRESETS = {
    "equal": MergePreset("Equal", "Equal weights for all models", "dare_ties", "equal"),
    "first_dominant": MergePreset("First Model Dominant", "Prioritize the first model", "dare_ties", "first_dominant"),
    "last_dominant": MergePreset("Last Model Dominant", "Prioritize the last model", "dare_ties", "last_dominant"),
    "coding_focus": MergePreset("Coding Focus", "Higher weight for code-related models", "dare_ties", "auto_detect"),
    "balanced_slerp": MergePreset("Balanced SLERP", "50/50 interpolation between two models", "slerp", "equal"),
}


# ===== CONFIG GENERATION =====

@dataclass
class MergeConfig:
    """Complete merge configuration."""

    method: str = "dare_ties"
    models: list[str] = field(default_factory=list)
    base_model: str = ""
    weights: list[float] = field(default_factory=list)
    densities: list[float] = field(default_factory=list)
    tokenizer_source: str = ""
    dtype: str = "bfloat16"
    # Method-specific params
    slerp_t: float = 0.5
    int8_mask: bool = True
    normalize: bool = True
    # Passthrough/slice params
    slices: list[dict] = field(default_factory=list)
    # Output
    output_name: str = ""

    def validate(self) -> list[str]:
        """Validate the configuration against MERGE_METHODS.

        Returns:
            A list of human-readable error messages; empty when the
            configuration is valid. An unknown method short-circuits with a
            single message because no further checks are possible.
        """
        errors = []
        method_info = MERGE_METHODS.get(self.method)
        if not method_info:
            errors.append(f"Unknown merge method: {self.method}")
            return errors

        label = method_info["name"]
        n = len(self.models)
        if n < method_info["min_models"]:
            errors.append(f"{label} requires at least {method_info['min_models']} models")
        if n > method_info["max_models"]:
            errors.append(f"{label} supports at most {method_info['max_models']} models")
        if method_info["needs_base"] and not self.base_model:
            errors.append(f"{label} requires a base_model")

        # Per-model lists are optional; they are only checked when supplied
        # and when the method actually consumes them.
        if "weight" in method_info["params"] and self.weights:
            if len(self.weights) != n:
                errors.append(f"Expected {n} weights, got {len(self.weights)}")
            if any(w < -1 or w > 2 for w in self.weights):
                errors.append("Weights should be between -1 and 2")
        if "density" in method_info["params"] and self.densities:
            if len(self.densities) != n:
                errors.append(f"Expected {n} densities, got {len(self.densities)}")
            if any(d < 0 or d > 1 for d in self.densities):
                errors.append("Densities must be between 0 and 1")

        if self.method == "slerp" and (self.slerp_t < 0 or self.slerp_t > 1):
            errors.append("SLERP t parameter must be between 0 and 1")
        if method_info.get("requires_slices") and not self.slices:
            errors.append(f"{label} requires slice definitions")
        return errors


def generate_yaml(config: MergeConfig) -> str:
    """Generate mergekit-compatible YAML configuration.

    Args:
        config: MergeConfig with all parameters

    Returns:
        YAML string ready for mergekit. If validation fails, a block of
        ``#``-prefixed comment lines listing the errors is returned instead.
    """
    errors = config.validate()
    if errors:
        return "# VALIDATION ERRORS:\n" + "\n".join(f"# - {e}" for e in errors)

    method_info = MERGE_METHODS[config.method]
    doc = {}

    # Passthrough uses slices format
    if config.method == "passthrough":
        # validate() currently rejects empty slices, so the fallback is
        # purely defensive.
        doc["slices"] = config.slices or _default_slices(config)
        doc["merge_method"] = config.method
        doc["dtype"] = config.dtype
        # Honour an explicit tokenizer choice here as the standard path does.
        if config.tokenizer_source:
            doc["tokenizer_source"] = config.tokenizer_source
        return yaml.dump(doc, default_flow_style=False, sort_keys=False)

    # Standard methods
    doc["merge_method"] = config.method
    if method_info["needs_base"]:
        doc["base_model"] = config.base_model
    elif config.method == "slerp":
        # mergekit's SLERP needs one of the two models named as base_model
        # (t=0 yields the base, t=1 the other); default to the first model.
        # NOTE(review): an explicit base_model that is not one of
        # config.models is passed through unchanged.
        doc["base_model"] = config.base_model or config.models[0]

    # Models with parameters
    if config.method == "slerp":
        doc["models"] = [{"model": m} for m in config.models]
        doc["parameters"] = {"t": config.slerp_t}
    else:
        use_weight = "weight" in method_info["params"] and bool(config.weights)
        use_density = "density" in method_info["params"] and bool(config.densities)
        models_list = []
        for i, model_id in enumerate(config.models):
            entry = {"model": model_id}
            params = {}
            if use_weight:
                params["weight"] = config.weights[i]
            if use_density:
                params["density"] = config.densities[i]
            if params:
                entry["parameters"] = params
            models_list.append(entry)
        doc["models"] = models_list

    # Global parameters (SLERP declares neither key, so its "t" block above
    # is never overwritten).
    declared = method_info.get("global_params", [])
    global_params = {
        key: getattr(config, key)
        for key in ("int8_mask", "normalize")
        if key in declared
    }
    if global_params:
        doc["parameters"] = global_params

    doc["dtype"] = config.dtype
    if config.tokenizer_source:
        doc["tokenizer_source"] = config.tokenizer_source
    return yaml.dump(doc, default_flow_style=False, sort_keys=False)


def _default_slices(config: MergeConfig, num_layers: int = 32) -> list[dict]:
    """Generate default slice config for passthrough merges.

    Args:
        config: Configuration whose ``models`` are each given one slice.
        num_layers: Exclusive upper bound of every ``layer_range``. The
            default of 32 is a fixed guess, not read from the models.

    Returns:
        One single-source slice per model, each covering ``[0, num_layers]``.
    """
    return [
        {"sources": [{"model": model_id, "layer_range": [0, num_layers]}]}
        for model_id in config.models
    ]


def generate_from_preset(
    preset_name: str,
    model_ids: list[str],
    base_model: str = "",
    tokenizer_source: str = "",
    dtype: str = "bfloat16",
) -> str:
    """Quick config generation from a preset name.

    Args:
        preset_name: Key from PRESETS dict
        model_ids: List of model IDs to merge
        base_model: Base model for methods that need one; defaults to the
            first model
        tokenizer_source: Which model's tokenizer to use; defaults to
            ``base_model``, then to the first model
        dtype: Data type for merge

    Returns:
        YAML string, or a ``#``-prefixed message listing the available
        presets when ``preset_name`` is unknown
    """
    preset = PRESETS.get(preset_name)
    if not preset:
        return f"# Unknown preset: {preset_name}\n# Available: {', '.join(PRESETS.keys())}"

    weights, densities = preset.apply(model_ids)
    first_model = model_ids[0] if model_ids else ""
    config = MergeConfig(
        method=preset.method,
        models=list(model_ids),  # copy so the config never aliases the caller's list
        base_model=base_model or first_model,
        weights=weights,
        densities=densities,
        tokenizer_source=tokenizer_source or base_model or first_model,
        dtype=dtype,
    )
    return generate_yaml(config)


def get_method_info(method: str) -> dict:
    """Get human-readable info about a merge method.

    Returns the MERGE_METHODS entry itself (not a copy), or a placeholder
    dict with "Unknown" name and description for unrecognised methods.
    """
    return MERGE_METHODS.get(method, {"name": "Unknown", "description": "Unknown method"})