"""Architecture compatibility checker for model merging."""
from dataclasses import dataclass, field
from typing import Optional
from .model_info import ModelInfo, fetch_model_info
@dataclass
class CompatibilityReport:
"""Result of compatibility checking between models."""
compatible: bool = True
errors: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
suggestions: list[str] = field(default_factory=list)
models_info: list[ModelInfo] = field(default_factory=list)
suggested_base: str = ""
suggested_tokenizer: str = ""
architecture: str = ""
merge_methods_available: list[str] = field(default_factory=list)
estimated_ram_gb: float = 0.0
estimated_merge_time: str = ""
@property
def status_emoji(self) -> str:
if not self.compatible:
return "❌"
elif self.warnings:
return "⚠️"
return "✅"
@property
def status_text(self) -> str:
if not self.compatible:
return "Incompatible — cannot merge"
elif self.warnings:
return "Compatible with warnings"
return "Fully compatible"
def to_markdown(self) -> str:
"""Generate a formatted markdown report."""
lines = []
# Header
lines.append(f"## {self.status_emoji} Compatibility Report")
lines.append("")
if self.architecture:
lines.append(f"**Architecture:** `{self.architecture}`")
lines.append("")
# Errors
if self.errors:
lines.append("### ❌ Errors")
for e in self.errors:
lines.append(f"- {e}")
lines.append("")
# Warnings
if self.warnings:
lines.append("### ⚠️ Warnings")
for w in self.warnings:
lines.append(f"- {w}")
lines.append("")
# Model details table
if self.models_info:
lines.append("### Model Details")
lines.append("| Model | Type | Hidden | Layers | Vocab | Params |")
lines.append("|-------|------|--------|--------|-------|--------|")
for m in self.models_info:
name = m.display_name
if len(name) > 35:
name = name[:32] + "..."
lines.append(
f"| {name} | `{m.model_type}` | {m.hidden_size} | "
f"{m.num_hidden_layers} | {m.vocab_size} | {m.param_estimate} |"
)
lines.append("")
# Suggestions
if self.suggestions:
lines.append("### 💡 Suggestions")
for s in self.suggestions:
lines.append(f"- {s}")
lines.append("")
# Merge methods
if self.merge_methods_available:
methods = ", ".join(f"`{m}`" for m in self.merge_methods_available)
lines.append(f"**Available merge methods:** {methods}")
lines.append("")
# Resource estimates
if self.estimated_ram_gb > 0:
lines.append(f"**Estimated RAM:** {self.estimated_ram_gb} GB")
lines.append(f"**Estimated time:** {self.estimated_merge_time}")
            colab_tier = (
                "Standard" if self.estimated_ram_gb <= 12
                else "High-RAM" if self.estimated_ram_gb <= 48
                else "A100 (Colab Pro+)"
            )
lines.append(f"**Recommended Colab:** {colab_tier}")
lines.append("")
if self.suggested_base:
lines.append(f"**Suggested base model:** `{self.suggested_base}`")
if self.suggested_tokenizer:
lines.append(f"**Suggested tokenizer:** `{self.suggested_tokenizer}`")
return "\n".join(lines)
def check_compatibility(
model_ids: list[str],
token: Optional[str] = None,
) -> CompatibilityReport:
"""Check if a list of models are compatible for merging.
Args:
model_ids: List of HuggingFace model IDs
token: Optional HF API token for gated models
Returns:
CompatibilityReport with detailed analysis
"""
report = CompatibilityReport()
# Validate input
if len(model_ids) < 2:
report.compatible = False
report.errors.append("At least 2 models are required for merging.")
return report
if len(model_ids) > 10:
report.warnings.append("Merging more than 10 models is unusual and may produce poor results.")
# Fetch all model info
for mid in model_ids:
mid = mid.strip()
if not mid:
continue
info = fetch_model_info(mid, token=token)
report.models_info.append(info)
if info.error:
if info.gated:
report.warnings.append(f"`{mid}`: Gated model — provide HF token to verify compatibility")
else:
report.compatible = False
report.errors.append(f"`{mid}`: {info.error}")
# If we couldn't fetch any models, bail
valid_models = [m for m in report.models_info if not m.error]
if len(valid_models) < 2:
report.compatible = False
if not report.errors:
report.errors.append("Could not fetch enough model configs to verify compatibility.")
return report
# === ARCHITECTURE CHECKS ===
# 1. model_type must match
types = set(m.model_type for m in valid_models)
if len(types) > 1:
report.compatible = False
report.errors.append(
f"Architecture mismatch! Found: {', '.join(f'`{t}`' for t in types)}. "
f"All models must share the same architecture to merge."
)
return report
report.architecture = valid_models[0].model_type
# 2. hidden_size must match
hidden_sizes = set(m.hidden_size for m in valid_models if m.hidden_size > 0)
if len(hidden_sizes) > 1:
report.compatible = False
report.errors.append(
f"Hidden size mismatch: {', '.join(str(s) for s in hidden_sizes)}. "
f"Models must have the same hidden dimension."
)
# 3. intermediate_size must match (for most methods)
inter_sizes = set(m.intermediate_size for m in valid_models if m.intermediate_size > 0)
if len(inter_sizes) > 1:
report.compatible = False
report.errors.append(
f"Intermediate size mismatch: {', '.join(str(s) for s in inter_sizes)}. "
f"Required for DARE-TIES, SLERP, and Linear methods."
)
# 4. num_hidden_layers — warn if different
layer_counts = set(m.num_hidden_layers for m in valid_models if m.num_hidden_layers > 0)
if len(layer_counts) > 1:
report.warnings.append(
f"Layer count differs: {', '.join(str(l) for l in layer_counts)}. "
f"Passthrough/Frankenmerge can handle this, but DARE-TIES/SLERP/Linear require matching layers."
)
# 5. vocab_size — warn if different
vocab_sizes = set(m.vocab_size for m in valid_models if m.vocab_size > 0)
if len(vocab_sizes) > 1:
report.warnings.append(
f"Vocabulary size differs: {', '.join(str(v) for v in vocab_sizes)}. "
f"Use `tokenizer_source` to specify which tokenizer to keep."
)
# 6. num_attention_heads / num_key_value_heads
head_counts = set(m.num_attention_heads for m in valid_models if m.num_attention_heads > 0)
kv_head_counts = set(m.num_key_value_heads for m in valid_models if m.num_key_value_heads > 0)
if len(head_counts) > 1:
report.compatible = False
report.errors.append(
f"Attention head count mismatch: {', '.join(str(h) for h in head_counts)}."
)
if len(kv_head_counts) > 1:
report.warnings.append(
f"KV head count differs: {', '.join(str(h) for h in kv_head_counts)}. "
f"This may cause issues with GQA models."
)
# 7. trust_remote_code warning
needs_trust = [m.model_id for m in valid_models if m.trust_remote_code]
if needs_trust:
report.warnings.append(
f"Models requiring `trust_remote_code=True`: "
f"{', '.join(f'`{m}`' for m in needs_trust)}"
)
# === SUGGESTIONS ===
# Suggest base model (most downloaded or original base if detectable)
if valid_models:
        # Prefer non-instruct checkpoints as the base (code models count as base here),
        # then break ties by download count
base_candidates = sorted(
valid_models,
key=lambda m: (
"instruct" in m.model_id.lower() and "code" not in m.model_id.lower(),
-m.downloads,
),
)
report.suggested_base = base_candidates[0].model_id
report.suggestions.append(f"Use `{report.suggested_base}` as the base model")
# Suggest tokenizer source (largest vocab)
if vocab_sizes and len(vocab_sizes) > 1:
largest_vocab_model = max(valid_models, key=lambda m: m.vocab_size)
report.suggested_tokenizer = largest_vocab_model.model_id
report.suggestions.append(
f"Use tokenizer from `{report.suggested_tokenizer}` (largest vocab: {largest_vocab_model.vocab_size})"
)
elif valid_models:
report.suggested_tokenizer = report.suggested_base
# === AVAILABLE MERGE METHODS ===
n = len(valid_models)
methods = []
if report.compatible:
# Linear always works if architectures match
methods.append("linear")
        # DARE-TIES and TIES need matching layer counts
if len(layer_counts) <= 1:
methods.append("dare_ties")
methods.append("ties")
# SLERP only for 2 models
if n == 2 and len(layer_counts) <= 1:
methods.append("slerp")
# Task arithmetic needs a base
methods.append("task_arithmetic")
# Passthrough works even with different layer counts
methods.append("passthrough")
report.merge_methods_available = methods
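    # As written above: two compatible models with matching layer counts get
    # ["linear", "dare_ties", "ties", "slerp", "task_arithmetic", "passthrough"],
    # while differing layer counts leave only ["linear", "task_arithmetic", "passthrough"].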
# === RESOURCE ESTIMATES ===
max_size = max((m.size_bytes for m in valid_models if m.size_bytes > 0), default=0)
if max_size > 0:
# Merging needs roughly: all models loaded + output
total_model_bytes = sum(m.size_bytes for m in valid_models if m.size_bytes > 0)
        # Rule of thumb: all inputs plus the output copy, with ~30% working overhead
ram_needed = (total_model_bytes + max_size) * 1.3
report.estimated_ram_gb = round(ram_needed / (1024**3), 1)
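        # Worked example (hypothetical sizes): two fp16 7B models of ~14 GB each give
        # a total of ~28 GB and a max of ~14 GB, so (28 + 14) * 1.3 = 54.6 GB estimated RAM.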
# Time estimate based on total size
total_gb = total_model_bytes / (1024**3)
if total_gb < 10:
report.estimated_merge_time = "5-15 minutes"
elif total_gb < 30:
report.estimated_merge_time = "15-30 minutes"
elif total_gb < 60:
report.estimated_merge_time = "30-60 minutes"
else:
report.estimated_merge_time = "1-2+ hours"
return report
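# Usage sketch (the model IDs below are placeholders, not a verified pairing):
#
#     report = check_compatibility(
#         ["org-a/example-7b", "org-b/example-7b-chat"],
#         token=None,  # pass an HF token here for gated repos
#     )
#     if report.compatible:
#         print(report.to_markdown())
#     else:
#         print("\n".join(report.errors))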
def quick_check(model_ids: list[str], token: Optional[str] = None) -> str:
"""Quick one-line compatibility check.
Returns a formatted string like:
"✅ Compatible (qwen2) | 3 models | ~32GB RAM | DARE-TIES, SLERP, Linear"
"""
report = check_compatibility(model_ids, token=token)
if not report.compatible:
errors = "; ".join(report.errors[:2])
return f"❌ {errors}"
methods = ", ".join(report.merge_methods_available[:3])
parts = [
f"{report.status_emoji} {report.status_text}",
f"Architecture: {report.architecture}",
f"{len(report.models_info)} models",
]
if report.estimated_ram_gb > 0:
parts.append(f"~{report.estimated_ram_gb}GB RAM")
parts.append(f"Methods: {methods}")
return " | ".join(parts)