"""Architecture compatibility checker for model merging."""

from dataclasses import dataclass, field
from typing import Optional

from .model_info import ModelInfo, fetch_model_info


@dataclass
class CompatibilityReport:
    """Result of compatibility checking between models."""

    compatible: bool = True
    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)
    suggestions: list[str] = field(default_factory=list)
    models_info: list[ModelInfo] = field(default_factory=list)
    suggested_base: str = ""
    suggested_tokenizer: str = ""
    architecture: str = ""
    merge_methods_available: list[str] = field(default_factory=list)
    estimated_ram_gb: float = 0.0
    estimated_merge_time: str = ""

    @property
    def status_emoji(self) -> str:
        if not self.compatible:
            return "❌"
        elif self.warnings:
            return "⚠️"
        return "✅"

    @property
    def status_text(self) -> str:
        if not self.compatible:
            return "Incompatible — cannot merge"
        elif self.warnings:
            return "Compatible with warnings"
        return "Fully compatible"

    def to_markdown(self) -> str:
        """Generate a formatted markdown report."""
        lines = []

        lines.append(f"## {self.status_emoji} Compatibility Report")
        lines.append("")

        if self.architecture:
            lines.append(f"**Architecture:** `{self.architecture}`")
            lines.append("")

        if self.errors:
            lines.append("### ❌ Errors")
            for e in self.errors:
                lines.append(f"- {e}")
            lines.append("")

        if self.warnings:
            lines.append("### ⚠️ Warnings")
            for w in self.warnings:
                lines.append(f"- {w}")
            lines.append("")

        if self.models_info:
            lines.append("### Model Details")
            lines.append("| Model | Type | Hidden | Layers | Vocab | Params |")
            lines.append("|-------|------|--------|--------|-------|--------|")
            for m in self.models_info:
                name = m.display_name
                if len(name) > 35:
                    name = name[:32] + "..."
                lines.append(
                    f"| {name} | `{m.model_type}` | {m.hidden_size} | "
                    f"{m.num_hidden_layers} | {m.vocab_size} | {m.param_estimate} |"
                )
            lines.append("")

        if self.suggestions:
            lines.append("### 💡 Suggestions")
            for s in self.suggestions:
                lines.append(f"- {s}")
            lines.append("")

        if self.merge_methods_available:
            methods = ", ".join(f"`{m}`" for m in self.merge_methods_available)
            lines.append(f"**Available merge methods:** {methods}")
            lines.append("")

        if self.estimated_ram_gb > 0:
            lines.append(f"**Estimated RAM:** {self.estimated_ram_gb} GB")
            lines.append(f"**Estimated time:** {self.estimated_merge_time}")
            if self.estimated_ram_gb <= 12:
                colab_tier = "Standard"
            elif self.estimated_ram_gb <= 48:
                colab_tier = "High-RAM"
            else:
                colab_tier = "A100 (Colab Pro+)"
            lines.append(f"**Recommended Colab:** {colab_tier}")
            lines.append("")

        if self.suggested_base:
            lines.append(f"**Suggested base model:** `{self.suggested_base}`")
        if self.suggested_tokenizer:
            lines.append(f"**Suggested tokenizer:** `{self.suggested_tokenizer}`")

        return "\n".join(lines)
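

# Rendering the report in a notebook (illustrative sketch; assumes a
# Jupyter/Colab context, which the Colab-tier hint in `to_markdown` also
# presumes; `IPython.display` ships with both, and the model IDs below are
# placeholders):
#
#     from IPython.display import Markdown, display
#
#     report = check_compatibility(["example-org/model-a", "example-org/model-b"])
#     display(Markdown(report.to_markdown()))

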
def check_compatibility(
    model_ids: list[str],
    token: Optional[str] = None,
) -> CompatibilityReport:
    """Check if a list of models are compatible for merging.

    Args:
        model_ids: List of HuggingFace model IDs
        token: Optional HF API token for gated models

    Returns:
        CompatibilityReport with detailed analysis
    """
    report = CompatibilityReport()

    if len(model_ids) < 2:
        report.compatible = False
        report.errors.append("At least 2 models are required for merging.")
        return report

    if len(model_ids) > 10:
        report.warnings.append(
            "Merging more than 10 models is unusual and may produce poor results."
        )

    # Fetch config metadata for each model; collect per-model errors as we go.
    for mid in model_ids:
        mid = mid.strip()
        if not mid:
            continue
        info = fetch_model_info(mid, token=token)
        report.models_info.append(info)

        if info.error:
            if info.gated:
                report.warnings.append(
                    f"`{mid}`: Gated model — provide HF token to verify compatibility"
                )
            else:
                report.compatible = False
                report.errors.append(f"`{mid}`: {info.error}")

    # At least two successfully fetched configs are needed to compare anything.
    valid_models = [m for m in report.models_info if not m.error]
    if len(valid_models) < 2:
        report.compatible = False
        if not report.errors:
            report.errors.append("Could not fetch enough model configs to verify compatibility.")
        return report

    # Hard requirement: every model must report the same architecture.
    types = {m.model_type for m in valid_models}
    if len(types) > 1:
        report.compatible = False
        report.errors.append(
            f"Architecture mismatch! Found: {', '.join(f'`{t}`' for t in types)}. "
            f"All models must share the same architecture to merge."
        )
        return report

    report.architecture = valid_models[0].model_type

    # Hidden dimension must match for any weight-space merge.
    hidden_sizes = {m.hidden_size for m in valid_models if m.hidden_size > 0}
    if len(hidden_sizes) > 1:
        report.compatible = False
        report.errors.append(
            f"Hidden size mismatch: {', '.join(str(s) for s in hidden_sizes)}. "
            f"Models must have the same hidden dimension."
        )

    inter_sizes = {m.intermediate_size for m in valid_models if m.intermediate_size > 0}
    if len(inter_sizes) > 1:
        report.compatible = False
        report.errors.append(
            f"Intermediate size mismatch: {', '.join(str(s) for s in inter_sizes)}. "
            f"Matching intermediate (MLP) sizes are required for the DARE-TIES, SLERP, and linear methods."
        )

    # Differing layer counts only block tensorwise merges, so warn rather than fail.
    layer_counts = {m.num_hidden_layers for m in valid_models if m.num_hidden_layers > 0}
    if len(layer_counts) > 1:
        report.warnings.append(
            f"Layer count differs: {', '.join(str(c) for c in layer_counts)}. "
            f"Passthrough/Frankenmerge can handle this, but DARE-TIES/SLERP/Linear require matching layers."
        )

    vocab_sizes = {m.vocab_size for m in valid_models if m.vocab_size > 0}
    if len(vocab_sizes) > 1:
        report.warnings.append(
            f"Vocabulary size differs: {', '.join(str(v) for v in vocab_sizes)}. "
            f"Use `tokenizer_source` to specify which tokenizer to keep."
        )

    head_counts = {m.num_attention_heads for m in valid_models if m.num_attention_heads > 0}
    kv_head_counts = {m.num_key_value_heads for m in valid_models if m.num_key_value_heads > 0}
    if len(head_counts) > 1:
        report.compatible = False
        report.errors.append(
            f"Attention head count mismatch: {', '.join(str(h) for h in head_counts)}."
        )
    if len(kv_head_counts) > 1:
        report.warnings.append(
            f"KV head count differs: {', '.join(str(h) for h in kv_head_counts)}. "
            f"This may cause issues with GQA models."
        )

    needs_trust = [m.model_id for m in valid_models if m.trust_remote_code]
    if needs_trust:
        report.warnings.append(
            f"Models requiring `trust_remote_code=True`: "
            f"{', '.join(f'`{m}`' for m in needs_trust)}"
        )

    if valid_models:
        # The sort key ranks instruct-tuned (non-code) models last, then
        # prefers the most-downloaded candidate as the suggested base.
        base_candidates = sorted(
            valid_models,
            key=lambda m: (
                "instruct" in m.model_id.lower() and "code" not in m.model_id.lower(),
                -m.downloads,
            ),
        )
        report.suggested_base = base_candidates[0].model_id
        report.suggestions.append(f"Use `{report.suggested_base}` as the base model")

    if len(vocab_sizes) > 1:
        largest_vocab_model = max(valid_models, key=lambda m: m.vocab_size)
        report.suggested_tokenizer = largest_vocab_model.model_id
        report.suggestions.append(
            f"Use tokenizer from `{report.suggested_tokenizer}` "
            f"(largest vocab: {largest_vocab_model.vocab_size})"
        )
    elif valid_models:
        report.suggested_tokenizer = report.suggested_base

    n = len(valid_models)
    methods = []

    if report.compatible:
        methods.append("linear")

        # DARE-TIES/TIES need matching layer counts across all models.
        if len(layer_counts) <= 1:
            methods.append("dare_ties")
            methods.append("ties")

        # SLERP interpolates between exactly two models.
        if n == 2 and len(layer_counts) <= 1:
            methods.append("slerp")

        methods.append("task_arithmetic")

        # Passthrough (frankenmerge) tolerates differing layer counts.
        methods.append("passthrough")

    report.merge_methods_available = methods

    # RAM estimate: hold all source weights plus one output-sized copy
    # (approximated by the largest input), with ~30% headroom for merge buffers.
    max_size = max((m.size_bytes for m in valid_models if m.size_bytes > 0), default=0)
    if max_size > 0:
        total_model_bytes = sum(m.size_bytes for m in valid_models if m.size_bytes > 0)
        ram_needed = (total_model_bytes + max_size) * 1.3
        report.estimated_ram_gb = round(ram_needed / (1024**3), 1)

        # Rough wall-clock buckets keyed to total weight volume.
        total_gb = total_model_bytes / (1024**3)
        if total_gb < 10:
            report.estimated_merge_time = "5-15 minutes"
        elif total_gb < 30:
            report.estimated_merge_time = "15-30 minutes"
        elif total_gb < 60:
            report.estimated_merge_time = "30-60 minutes"
        else:
            report.estimated_merge_time = "1-2+ hours"

    return report
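

# Acting on the report (illustrative sketch; `run_merge` is a hypothetical
# downstream entry point, not part of this module, and the model IDs are
# placeholders):
#
#     report = check_compatibility(["example-org/model-a", "example-org/model-b"])
#     if report.compatible and "slerp" in report.merge_methods_available:
#         run_merge(method="slerp", base_model=report.suggested_base)

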
def quick_check(model_ids: list[str], token: Optional[str] = None) -> str:
    """Quick one-line compatibility check.

    Returns a formatted string like:
        "✅ Fully compatible | Architecture: qwen2 | 3 models | ~32.5GB RAM | Methods: linear, dare_ties, ties"
    """
    report = check_compatibility(model_ids, token=token)

    if not report.compatible:
        errors = "; ".join(report.errors[:2])
        return f"❌ {errors}"

    methods = ", ".join(report.merge_methods_available[:3])
    parts = [
        f"{report.status_emoji} {report.status_text}",
        f"Architecture: {report.architecture}",
        f"{len(report.models_info)} models",
    ]
    if report.estimated_ram_gb > 0:
        parts.append(f"~{report.estimated_ram_gb}GB RAM")
    parts.append(f"Methods: {methods}")

    return " | ".join(parts)
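

if __name__ == "__main__":
    # Minimal smoke test; run with `python -m <package>.<module>` so the
    # relative import resolves. The model IDs are placeholders, not
    # recommendations; substitute any two Hub checkpoints you can access.
    print(quick_check(["example-org/model-a-7b", "example-org/model-b-7b"]))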