File size: 11,605 Bytes
0249933
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
"""Architecture compatibility checker for model merging."""

from dataclasses import dataclass, field
from typing import Optional
from .model_info import ModelInfo, fetch_model_info


@dataclass
class CompatibilityReport:
    """Result of compatibility checking between models.

    Accumulates blocking ``errors``, non-blocking ``warnings``, actionable
    ``suggestions``, per-model metadata, and resource estimates, and can
    render itself as markdown via :meth:`to_markdown`.
    """
    compatible: bool = True                 # False once any blocking error is found
    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)
    suggestions: list[str] = field(default_factory=list)
    models_info: list[ModelInfo] = field(default_factory=list)
    suggested_base: str = ""                # model id recommended as merge base
    suggested_tokenizer: str = ""           # model id recommended as tokenizer source
    architecture: str = ""                  # shared model_type once verified
    merge_methods_available: list[str] = field(default_factory=list)
    estimated_ram_gb: float = 0.0
    estimated_merge_time: str = ""

    @property
    def status_emoji(self) -> str:
        """Emoji for overall status: error > warning > OK."""
        if not self.compatible:
            return "❌"
        elif self.warnings:
            return "⚠️"
        # Fixed mojibake: was "βœ…" (UTF-8 "✅" decoded with the wrong codec).
        return "✅"

    @property
    def status_text(self) -> str:
        """Human-readable one-line status matching :attr:`status_emoji`."""
        if not self.compatible:
            # Fixed mojibake: "β€”" was a garbled em dash.
            return "Incompatible — cannot merge"
        elif self.warnings:
            return "Compatible with warnings"
        return "Fully compatible"

    def to_markdown(self) -> str:
        """Generate a formatted markdown report.

        Sections are emitted only when they have content: errors, warnings,
        a per-model detail table, suggestions, available merge methods,
        resource estimates, and suggested base/tokenizer.
        """
        lines = []

        # Header
        lines.append(f"## {self.status_emoji} Compatibility Report")
        lines.append("")

        if self.architecture:
            lines.append(f"**Architecture:** `{self.architecture}`")
            lines.append("")

        # Errors
        if self.errors:
            lines.append("### ❌ Errors")
            for e in self.errors:
                lines.append(f"- {e}")
            lines.append("")

        # Warnings
        if self.warnings:
            lines.append("### ⚠️ Warnings")
            for w in self.warnings:
                lines.append(f"- {w}")
            lines.append("")

        # Model details table
        if self.models_info:
            lines.append("### Model Details")
            lines.append("| Model | Type | Hidden | Layers | Vocab | Params |")
            lines.append("|-------|------|--------|--------|-------|--------|")
            for m in self.models_info:
                name = m.display_name
                # Truncate long names so the table stays readable.
                if len(name) > 35:
                    name = name[:32] + "..."
                lines.append(
                    f"| {name} | `{m.model_type}` | {m.hidden_size} | "
                    f"{m.num_hidden_layers} | {m.vocab_size} | {m.param_estimate} |"
                )
            lines.append("")

        # Suggestions — fixed mojibake: "πŸ’‘" was a garbled "💡".
        if self.suggestions:
            lines.append("### 💡 Suggestions")
            for s in self.suggestions:
                lines.append(f"- {s}")
            lines.append("")

        # Merge methods
        if self.merge_methods_available:
            methods = ", ".join(f"`{m}`" for m in self.merge_methods_available)
            lines.append(f"**Available merge methods:** {methods}")
            lines.append("")

        # Resource estimates (only shown when an estimate was computed)
        if self.estimated_ram_gb > 0:
            lines.append(f"**Estimated RAM:** {self.estimated_ram_gb} GB")
            lines.append(f"**Estimated time:** {self.estimated_merge_time}")
            colab_tier = "Standard" if self.estimated_ram_gb <= 12 else "High-RAM" if self.estimated_ram_gb <= 48 else "A100 (Colab Pro+)"
            lines.append(f"**Recommended Colab:** {colab_tier}")
            lines.append("")

        if self.suggested_base:
            lines.append(f"**Suggested base model:** `{self.suggested_base}`")
        if self.suggested_tokenizer:
            lines.append(f"**Suggested tokenizer:** `{self.suggested_tokenizer}`")

        return "\n".join(lines)


def check_compatibility(
    model_ids: list[str],
    token: Optional[str] = None,
) -> CompatibilityReport:
    """Check if a list of models are compatible for merging.

    Fetches each model's config from the Hub, verifies architecture-level
    compatibility (model_type, hidden/intermediate sizes, attention heads),
    warns on soft mismatches (layer count, vocab, KV heads), then fills in
    suggestions, available merge methods, and resource estimates.

    Args:
        model_ids: List of HuggingFace model IDs
        token: Optional HF API token for gated models

    Returns:
        CompatibilityReport with detailed analysis
    """
    report = CompatibilityReport()

    # Validate input
    if len(model_ids) < 2:
        report.compatible = False
        report.errors.append("At least 2 models are required for merging.")
        return report

    if len(model_ids) > 10:
        report.warnings.append("Merging more than 10 models is unusual and may produce poor results.")

    # Fetch all model info
    for mid in model_ids:
        mid = mid.strip()
        if not mid:
            continue
        info = fetch_model_info(mid, token=token)
        report.models_info.append(info)

        if info.error:
            if info.gated:
                # Gated models can't be inspected without a token; warn but
                # don't hard-fail. (Fixed mojibake: "β€”" -> em dash.)
                report.warnings.append(f"`{mid}`: Gated model — provide HF token to verify compatibility")
            else:
                report.compatible = False
                report.errors.append(f"`{mid}`: {info.error}")

    # If we couldn't fetch any models, bail
    valid_models = [m for m in report.models_info if not m.error]
    if len(valid_models) < 2:
        report.compatible = False
        if not report.errors:
            report.errors.append("Could not fetch enough model configs to verify compatibility.")
        return report

    # === ARCHITECTURE CHECKS ===

    # 1. model_type must match — a mismatch here is fatal, so return early.
    types = set(m.model_type for m in valid_models)
    if len(types) > 1:
        report.compatible = False
        report.errors.append(
            f"Architecture mismatch! Found: {', '.join(f'`{t}`' for t in types)}. "
            f"All models must share the same architecture to merge."
        )
        return report

    report.architecture = valid_models[0].model_type

    # 2. hidden_size must match
    hidden_sizes = set(m.hidden_size for m in valid_models if m.hidden_size > 0)
    if len(hidden_sizes) > 1:
        report.compatible = False
        report.errors.append(
            f"Hidden size mismatch: {', '.join(str(s) for s in hidden_sizes)}. "
            f"Models must have the same hidden dimension."
        )

    # 3. intermediate_size must match (for most methods)
    inter_sizes = set(m.intermediate_size for m in valid_models if m.intermediate_size > 0)
    if len(inter_sizes) > 1:
        report.compatible = False
        report.errors.append(
            f"Intermediate size mismatch: {', '.join(str(s) for s in inter_sizes)}. "
            f"Required for DARE-TIES, SLERP, and Linear methods."
        )

    # 4. num_hidden_layers — warn if different (passthrough can still work)
    layer_counts = set(m.num_hidden_layers for m in valid_models if m.num_hidden_layers > 0)
    if len(layer_counts) > 1:
        report.warnings.append(
            f"Layer count differs: {', '.join(str(l) for l in layer_counts)}. "
            f"Passthrough/Frankenmerge can handle this, but DARE-TIES/SLERP/Linear require matching layers."
        )

    # 5. vocab_size — warn if different
    vocab_sizes = set(m.vocab_size for m in valid_models if m.vocab_size > 0)
    if len(vocab_sizes) > 1:
        report.warnings.append(
            f"Vocabulary size differs: {', '.join(str(v) for v in vocab_sizes)}. "
            f"Use `tokenizer_source` to specify which tokenizer to keep."
        )

    # 6. num_attention_heads / num_key_value_heads
    head_counts = set(m.num_attention_heads for m in valid_models if m.num_attention_heads > 0)
    kv_head_counts = set(m.num_key_value_heads for m in valid_models if m.num_key_value_heads > 0)
    if len(head_counts) > 1:
        report.compatible = False
        report.errors.append(
            f"Attention head count mismatch: {', '.join(str(h) for h in head_counts)}."
        )
    if len(kv_head_counts) > 1:
        report.warnings.append(
            f"KV head count differs: {', '.join(str(h) for h in kv_head_counts)}. "
            f"This may cause issues with GQA models."
        )

    # 7. trust_remote_code warning
    needs_trust = [m.model_id for m in valid_models if m.trust_remote_code]
    if needs_trust:
        report.warnings.append(
            f"Models requiring `trust_remote_code=True`: "
            f"{', '.join(f'`{m}`' for m in needs_trust)}"
        )

    # === SUGGESTIONS ===

    # Suggest a base model. The sort key places instruct-tuned models
    # (excluding "code" models) LAST, then orders by downloads descending,
    # so candidate [0] is the most-downloaded non-instruct model.
    # NOTE(review): the original comment said "prefer instruct", which the
    # key does NOT do — confirm the intended preference with the author.
    if valid_models:
        base_candidates = sorted(
            valid_models,
            key=lambda m: (
                "instruct" in m.model_id.lower() and "code" not in m.model_id.lower(),
                -m.downloads,
            ),
        )
        report.suggested_base = base_candidates[0].model_id
        report.suggestions.append(f"Use `{report.suggested_base}` as the base model")

    # Suggest tokenizer source (largest vocab) when vocabularies differ
    if vocab_sizes and len(vocab_sizes) > 1:
        largest_vocab_model = max(valid_models, key=lambda m: m.vocab_size)
        report.suggested_tokenizer = largest_vocab_model.model_id
        report.suggestions.append(
            f"Use tokenizer from `{report.suggested_tokenizer}` (largest vocab: {largest_vocab_model.vocab_size})"
        )
    elif valid_models:
        report.suggested_tokenizer = report.suggested_base

    # === AVAILABLE MERGE METHODS ===
    n = len(valid_models)
    methods = []

    if report.compatible:
        # Linear always works if architectures match
        methods.append("linear")

        # DARE-TIES needs matching layers
        if len(layer_counts) <= 1:
            methods.append("dare_ties")
            methods.append("ties")

        # SLERP only for 2 models
        if n == 2 and len(layer_counts) <= 1:
            methods.append("slerp")

        # Task arithmetic needs a base
        methods.append("task_arithmetic")

    # Passthrough works even with different layer counts
    methods.append("passthrough")

    report.merge_methods_available = methods

    # === RESOURCE ESTIMATES ===
    max_size = max((m.size_bytes for m in valid_models if m.size_bytes > 0), default=0)
    if max_size > 0:
        # Merging needs roughly: all models loaded + output
        total_model_bytes = sum(m.size_bytes for m in valid_models if m.size_bytes > 0)
        # Rule of thumb: need models + 50% overhead (1.3x applied below)
        ram_needed = (total_model_bytes + max_size) * 1.3
        report.estimated_ram_gb = round(ram_needed / (1024**3), 1)

        # Time estimate based on total size
        total_gb = total_model_bytes / (1024**3)
        if total_gb < 10:
            report.estimated_merge_time = "5-15 minutes"
        elif total_gb < 30:
            report.estimated_merge_time = "15-30 minutes"
        elif total_gb < 60:
            report.estimated_merge_time = "30-60 minutes"
        else:
            report.estimated_merge_time = "1-2+ hours"

    return report


def quick_check(model_ids: list[str], token: Optional[str] = None) -> str:
    """Quick one-line compatibility check.

    Runs the full compatibility analysis and condenses it to a single
    pipe-separated line, e.g.:
    "✅ Compatible (qwen2) | 3 models | ~32GB RAM | DARE-TIES, SLERP, Linear"
    """
    report = check_compatibility(model_ids, token=token)

    # Incompatible: show only the first couple of errors, joined.
    if not report.compatible:
        return f"❌ {'; '.join(report.errors[:2])}"

    segments = [
        f"{report.status_emoji} {report.status_text}",
        f"Architecture: {report.architecture}",
        f"{len(report.models_info)} models",
    ]
    if report.estimated_ram_gb > 0:
        segments.append(f"~{report.estimated_ram_gb}GB RAM")
    segments.append(f"Methods: {', '.join(report.merge_methods_available[:3])}")

    return " | ".join(segments)