"""Architecture compatibility checker for model merging.""" from dataclasses import dataclass, field from typing import Optional from .model_info import ModelInfo, fetch_model_info @dataclass class CompatibilityReport: """Result of compatibility checking between models.""" compatible: bool = True errors: list[str] = field(default_factory=list) warnings: list[str] = field(default_factory=list) suggestions: list[str] = field(default_factory=list) models_info: list[ModelInfo] = field(default_factory=list) suggested_base: str = "" suggested_tokenizer: str = "" architecture: str = "" merge_methods_available: list[str] = field(default_factory=list) estimated_ram_gb: float = 0.0 estimated_merge_time: str = "" @property def status_emoji(self) -> str: if not self.compatible: return "❌" elif self.warnings: return "⚠️" return "✅" @property def status_text(self) -> str: if not self.compatible: return "Incompatible — cannot merge" elif self.warnings: return "Compatible with warnings" return "Fully compatible" def to_markdown(self) -> str: """Generate a formatted markdown report.""" lines = [] # Header lines.append(f"## {self.status_emoji} Compatibility Report") lines.append("") if self.architecture: lines.append(f"**Architecture:** `{self.architecture}`") lines.append("") # Errors if self.errors: lines.append("### ❌ Errors") for e in self.errors: lines.append(f"- {e}") lines.append("") # Warnings if self.warnings: lines.append("### ⚠️ Warnings") for w in self.warnings: lines.append(f"- {w}") lines.append("") # Model details table if self.models_info: lines.append("### Model Details") lines.append("| Model | Type | Hidden | Layers | Vocab | Params |") lines.append("|-------|------|--------|--------|-------|--------|") for m in self.models_info: name = m.display_name if len(name) > 35: name = name[:32] + "..." lines.append( f"| {name} | `{m.model_type}` | {m.hidden_size} | " f"{m.num_hidden_layers} | {m.vocab_size} | {m.param_estimate} |" ) lines.append("") # Suggestions if self.suggestions: lines.append("### 💡 Suggestions") for s in self.suggestions: lines.append(f"- {s}") lines.append("") # Merge methods if self.merge_methods_available: methods = ", ".join(f"`{m}`" for m in self.merge_methods_available) lines.append(f"**Available merge methods:** {methods}") lines.append("") # Resource estimates if self.estimated_ram_gb > 0: lines.append(f"**Estimated RAM:** {self.estimated_ram_gb} GB") lines.append(f"**Estimated time:** {self.estimated_merge_time}") colab_tier = "Standard" if self.estimated_ram_gb <= 12 else "High-RAM" if self.estimated_ram_gb <= 48 else "A100 (Colab Pro+)" lines.append(f"**Recommended Colab:** {colab_tier}") lines.append("") if self.suggested_base: lines.append(f"**Suggested base model:** `{self.suggested_base}`") if self.suggested_tokenizer: lines.append(f"**Suggested tokenizer:** `{self.suggested_tokenizer}`") return "\n".join(lines) def check_compatibility( model_ids: list[str], token: Optional[str] = None, ) -> CompatibilityReport: """Check if a list of models are compatible for merging. Args: model_ids: List of HuggingFace model IDs token: Optional HF API token for gated models Returns: CompatibilityReport with detailed analysis """ report = CompatibilityReport() # Validate input if len(model_ids) < 2: report.compatible = False report.errors.append("At least 2 models are required for merging.") return report if len(model_ids) > 10: report.warnings.append("Merging more than 10 models is unusual and may produce poor results.") # Fetch all model info for mid in model_ids: mid = mid.strip() if not mid: continue info = fetch_model_info(mid, token=token) report.models_info.append(info) if info.error: if info.gated: report.warnings.append(f"`{mid}`: Gated model — provide HF token to verify compatibility") else: report.compatible = False report.errors.append(f"`{mid}`: {info.error}") # If we couldn't fetch any models, bail valid_models = [m for m in report.models_info if not m.error] if len(valid_models) < 2: report.compatible = False if not report.errors: report.errors.append("Could not fetch enough model configs to verify compatibility.") return report # === ARCHITECTURE CHECKS === # 1. model_type must match types = set(m.model_type for m in valid_models) if len(types) > 1: report.compatible = False report.errors.append( f"Architecture mismatch! Found: {', '.join(f'`{t}`' for t in types)}. " f"All models must share the same architecture to merge." ) return report report.architecture = valid_models[0].model_type # 2. hidden_size must match hidden_sizes = set(m.hidden_size for m in valid_models if m.hidden_size > 0) if len(hidden_sizes) > 1: report.compatible = False report.errors.append( f"Hidden size mismatch: {', '.join(str(s) for s in hidden_sizes)}. " f"Models must have the same hidden dimension." ) # 3. intermediate_size must match (for most methods) inter_sizes = set(m.intermediate_size for m in valid_models if m.intermediate_size > 0) if len(inter_sizes) > 1: report.compatible = False report.errors.append( f"Intermediate size mismatch: {', '.join(str(s) for s in inter_sizes)}. " f"Required for DARE-TIES, SLERP, and Linear methods." ) # 4. num_hidden_layers — warn if different layer_counts = set(m.num_hidden_layers for m in valid_models if m.num_hidden_layers > 0) if len(layer_counts) > 1: report.warnings.append( f"Layer count differs: {', '.join(str(l) for l in layer_counts)}. " f"Passthrough/Frankenmerge can handle this, but DARE-TIES/SLERP/Linear require matching layers." ) # 5. vocab_size — warn if different vocab_sizes = set(m.vocab_size for m in valid_models if m.vocab_size > 0) if len(vocab_sizes) > 1: report.warnings.append( f"Vocabulary size differs: {', '.join(str(v) for v in vocab_sizes)}. " f"Use `tokenizer_source` to specify which tokenizer to keep." ) # 6. num_attention_heads / num_key_value_heads head_counts = set(m.num_attention_heads for m in valid_models if m.num_attention_heads > 0) kv_head_counts = set(m.num_key_value_heads for m in valid_models if m.num_key_value_heads > 0) if len(head_counts) > 1: report.compatible = False report.errors.append( f"Attention head count mismatch: {', '.join(str(h) for h in head_counts)}." ) if len(kv_head_counts) > 1: report.warnings.append( f"KV head count differs: {', '.join(str(h) for h in kv_head_counts)}. " f"This may cause issues with GQA models." ) # 7. trust_remote_code warning needs_trust = [m.model_id for m in valid_models if m.trust_remote_code] if needs_trust: report.warnings.append( f"Models requiring `trust_remote_code=True`: " f"{', '.join(f'`{m}`' for m in needs_trust)}" ) # === SUGGESTIONS === # Suggest base model (most downloaded or original base if detectable) if valid_models: # Prefer instruct/base versions, then most downloaded base_candidates = sorted( valid_models, key=lambda m: ( "instruct" in m.model_id.lower() and "code" not in m.model_id.lower(), -m.downloads, ), ) report.suggested_base = base_candidates[0].model_id report.suggestions.append(f"Use `{report.suggested_base}` as the base model") # Suggest tokenizer source (largest vocab) if vocab_sizes and len(vocab_sizes) > 1: largest_vocab_model = max(valid_models, key=lambda m: m.vocab_size) report.suggested_tokenizer = largest_vocab_model.model_id report.suggestions.append( f"Use tokenizer from `{report.suggested_tokenizer}` (largest vocab: {largest_vocab_model.vocab_size})" ) elif valid_models: report.suggested_tokenizer = report.suggested_base # === AVAILABLE MERGE METHODS === n = len(valid_models) methods = [] if report.compatible: # Linear always works if architectures match methods.append("linear") # DARE-TIES needs matching layers if len(layer_counts) <= 1: methods.append("dare_ties") methods.append("ties") # SLERP only for 2 models if n == 2 and len(layer_counts) <= 1: methods.append("slerp") # Task arithmetic needs a base methods.append("task_arithmetic") # Passthrough works even with different layer counts methods.append("passthrough") report.merge_methods_available = methods # === RESOURCE ESTIMATES === max_size = max((m.size_bytes for m in valid_models if m.size_bytes > 0), default=0) if max_size > 0: # Merging needs roughly: all models loaded + output total_model_bytes = sum(m.size_bytes for m in valid_models if m.size_bytes > 0) # Rule of thumb: need models + 50% overhead ram_needed = (total_model_bytes + max_size) * 1.3 report.estimated_ram_gb = round(ram_needed / (1024**3), 1) # Time estimate based on total size total_gb = total_model_bytes / (1024**3) if total_gb < 10: report.estimated_merge_time = "5-15 minutes" elif total_gb < 30: report.estimated_merge_time = "15-30 minutes" elif total_gb < 60: report.estimated_merge_time = "30-60 minutes" else: report.estimated_merge_time = "1-2+ hours" return report def quick_check(model_ids: list[str], token: Optional[str] = None) -> str: """Quick one-line compatibility check. Returns a formatted string like: "✅ Compatible (qwen2) | 3 models | ~32GB RAM | DARE-TIES, SLERP, Linear" """ report = check_compatibility(model_ids, token=token) if not report.compatible: errors = "; ".join(report.errors[:2]) return f"❌ {errors}" methods = ", ".join(report.merge_methods_available[:3]) parts = [ f"{report.status_emoji} {report.status_text}", f"Architecture: {report.architecture}", f"{len(report.models_info)} models", ] if report.estimated_ram_gb > 0: parts.append(f"~{report.estimated_ram_gb}GB RAM") parts.append(f"Methods: {methods}") return " | ".join(parts)