"""HuggingFace Hub API wrapper for model discovery and info retrieval.""" |
|
|
|
|
|
import json |
|
|
import time |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Optional |
|
|
from functools import lru_cache |
|
|
|
|
|
import requests |
|
|
|
|
|
HF_API = "https://huggingface.co/api" |
|
|
_session = requests.Session() |
|
|
_session.headers.update({"Accept": "application/json"}) |
|
|
|
|
|
|
|
|
_cache: dict[str, tuple[float, any]] = {} |
|
|
CACHE_TTL = 300 |
|
|
|
|
|
|
|
|
def _cached_get(url: str, token: Optional[str] = None, ttl: int = CACHE_TTL) -> dict: |
|
|
"""GET with caching and rate-limit handling.""" |
|
|
now = time.time() |
|
|
if url in _cache and (now - _cache[url][0]) < ttl: |
|
|
return _cache[url][1] |
|
|
|
|
|
headers = {} |
|
|
if token: |
|
|
headers["Authorization"] = f"Bearer {token}" |
|
|
|
|
|
resp = _session.get(url, headers=headers, timeout=15) |
|
|
|
|
|
if resp.status_code == 429: |
|
|
retry = int(resp.headers.get("Retry-After", 5)) |
|
|
time.sleep(retry) |
|
|
resp = _session.get(url, headers=headers, timeout=15) |
|
|
|
|
|
resp.raise_for_status() |
|
|
data = resp.json() |
|
|
_cache[url] = (now, data) |
|
|
return data |
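
# Usage sketch for the cache behavior (the model id below is illustrative,
# not something this module depends on): a second call within CACHE_TTL is
# served from _cache and never touches the network.
#
#     info = _cached_get(f"{HF_API}/models/gpt2")
#     info = _cached_get(f"{HF_API}/models/gpt2")  # cache hit, no HTTP request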


@dataclass
class ModelInfo:
    """Parsed model information from HF Hub."""

    model_id: str
    model_type: str = "unknown"
    architectures: list[str] = field(default_factory=list)
    vocab_size: int = 0
    hidden_size: int = 0
    intermediate_size: int = 0
    num_hidden_layers: int = 0
    num_attention_heads: int = 0
    num_key_value_heads: int = 0
    max_position_embeddings: int = 0
    torch_dtype: str = "unknown"
    pipeline_tag: str = ""
    tags: list[str] = field(default_factory=list)
    downloads: int = 0
    likes: int = 0
    size_bytes: int = 0
    gated: bool = False
    private: bool = False
    trust_remote_code: bool = False
    error: Optional[str] = None

    @property
    def param_estimate(self) -> str:
        """Rough parameter count estimated from total weight file size."""
        if self.size_bytes > 0:
            # Assume ~2 bytes per parameter (fp16/bf16 weights).
            params = self.size_bytes / 2
            if params > 1e9:
                return f"{params/1e9:.1f}B"
            elif params > 1e6:
                return f"{params/1e6:.0f}M"
        return "unknown"

    @property
    def arch_signature(self) -> str:
        """Signature used for architecture compatibility matching."""
        return f"{self.model_type}|{self.hidden_size}|{self.intermediate_size}"

    @property
    def display_name(self) -> str:
        """Short display name (without the org prefix)."""
        return self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id

    @property
    def ram_estimate_gb(self) -> float:
        """Estimated RAM needed for merging (roughly 2.5x model size for a bf16 merge)."""
        if self.size_bytes > 0:
            return round(self.size_bytes * 2.5 / (1024**3), 1)
        return 0.0

    def to_dict(self) -> dict:
        return {
            "model_id": self.model_id,
            "model_type": self.model_type,
            "architectures": self.architectures,
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "intermediate_size": self.intermediate_size,
            "num_hidden_layers": self.num_hidden_layers,
            "num_attention_heads": self.num_attention_heads,
            "torch_dtype": self.torch_dtype,
            "pipeline_tag": self.pipeline_tag,
            "downloads": self.downloads,
            "likes": self.likes,
            "param_estimate": self.param_estimate,
            "ram_estimate_gb": self.ram_estimate_gb,
            "gated": self.gated,
            "private": self.private,
        }
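
# Worked example of the estimates above (numbers are hypothetical): a 14 GB
# fp16/bf16 checkpoint implies 14e9 / 2 = 7e9 parameters ("7.0B") and a merge
# footprint of 14e9 * 2.5 / 1024**3 = 32.6 GB.
#
#     info = ModelInfo(model_id="example-org/example-7b", size_bytes=14_000_000_000)
#     info.param_estimate   # "7.0B"
#     info.ram_estimate_gb  # 32.6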


def fetch_model_info(model_id: str, token: Optional[str] = None) -> ModelInfo:
    """Fetch comprehensive model information from HF Hub.

    Args:
        model_id: Full model ID (e.g., "Qwen/Qwen2.5-Coder-7B-Instruct")
        token: Optional HF API token for gated/private models

    Returns:
        ModelInfo dataclass with all available information
    """
    info = ModelInfo(model_id=model_id)

    try:
        data = _cached_get(f"{HF_API}/models/{model_id}", token=token)
    except requests.exceptions.HTTPError as e:
        if e.response.status_code in (401, 403):
            info.error = "Gated or private model; HF token required"
            info.gated = True
        elif e.response.status_code == 404:
            info.error = f"Model not found: {model_id}"
        else:
            info.error = f"API error: {e.response.status_code}"
        return info
    except Exception as e:
        info.error = f"Connection error: {e}"
        return info

    info.pipeline_tag = data.get("pipeline_tag", "")
    info.tags = data.get("tags", [])
    info.downloads = data.get("downloads", 0)
    info.likes = data.get("likes", 0)
    # The API reports gated as False or a mode string such as "auto"/"manual".
    info.gated = data.get("gated", False) not in (False, None)
    info.private = data.get("private", False)

    config = data.get("config", {})
    if config:
        info.model_type = config.get("model_type", "unknown")
        info.architectures = config.get("architectures", [])

    # The API's embedded config is partial; fetch the full config.json for
    # the architecture details.
    try:
        full_config = _cached_get(
            f"https://huggingface.co/{model_id}/resolve/main/config.json",
            token=token,
        )
        info.model_type = full_config.get("model_type", info.model_type)
        info.architectures = full_config.get("architectures", info.architectures)
        info.vocab_size = full_config.get("vocab_size", 0)
        info.hidden_size = full_config.get("hidden_size", 0)
        info.intermediate_size = full_config.get("intermediate_size", 0)
        info.num_hidden_layers = full_config.get("num_hidden_layers", 0)
        info.num_attention_heads = full_config.get("num_attention_heads", 0)
        info.num_key_value_heads = full_config.get("num_key_value_heads", 0)
        info.max_position_embeddings = full_config.get("max_position_embeddings", 0)
        info.torch_dtype = full_config.get("torch_dtype", "unknown")

        # Repos with custom code ship an auto_map and need trust_remote_code=True.
        if "auto_map" in full_config:
            info.trust_remote_code = True
    except Exception:
        # Fall back to whatever the partial API config provided.
        if config:
            info.vocab_size = config.get("vocab_size", 0)
            info.hidden_size = config.get("hidden_size", 0)
        else:
            info.error = "Could not fetch config.json; model may need trust_remote_code=True"
            info.trust_remote_code = True

    # Sum weight files to estimate model size; missing sizes count as 0.
    siblings = data.get("siblings", [])
    total_size = 0
    for f in siblings:
        fname = f.get("rfilename", "")
        size = f.get("size", 0) or 0
        if any(fname.endswith(ext) for ext in
               [".safetensors", ".bin", ".pt", ".pth", ".gguf"]):
            total_size += size
    info.size_bytes = total_size

    return info
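
# Usage sketch (performs live API calls; the model id is illustrative):
#
#     info = fetch_model_info("Qwen/Qwen2.5-7B-Instruct")
#     if info.error:
#         print(info.error)
#     else:
#         print(info.display_name, info.param_estimate, f"{info.ram_estimate_gb} GB")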


def search_models(
    query: str = "",
    author: str = "",
    architecture: str = "",
    limit: int = 20,
    sort: str = "downloads",
    token: Optional[str] = None,
) -> list[dict]:
    """Search HuggingFace Hub for models.

    Args:
        query: Search query string
        author: Filter by author/organization
        architecture: Filter by model_type (e.g., "llama", "qwen2")
        limit: Max results to return
        sort: Sort by "downloads", "likes", "created", "modified"
        token: Optional HF API token

    Returns:
        List of dicts with basic model info
    """
    params = {
        "limit": min(limit, 100),
        "sort": sort,
        "direction": -1,
        "config": True,
    }
    if query:
        params["search"] = query
    if author:
        params["author"] = author

    try:
        # urlencode escapes the query safely and keeps the cache key stable.
        data = _cached_get(f"{HF_API}/models?{urlencode(params)}", token=token, ttl=60)
    except Exception as e:
        return [{"error": str(e)}]

    results = []
    for m in data:
        config = m.get("config", {}) or {}
        model_type = config.get("model_type", "")

        # Filter by architecture client-side using the returned config.
        if architecture and model_type.lower() != architecture.lower():
            continue

        results.append({
            "model_id": m.get("modelId", ""),
            "model_type": model_type,
            "pipeline_tag": m.get("pipeline_tag", ""),
            "downloads": m.get("downloads", 0),
            "likes": m.get("likes", 0),
            "tags": m.get("tags", [])[:5],
        })

    return results[:limit]
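
# Usage sketch (live API call; filter values are illustrative):
#
#     for m in search_models(query="coder", architecture="qwen2", limit=5):
#         print(m["model_id"], m["downloads"])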


def get_popular_base_models(architecture: str = "", token: Optional[str] = None) -> list[dict]:
    """Get popular base models for a given architecture type.

    Useful for suggesting base_model in merge configs.
    """
    known_bases = {
        "llama": [
            "meta-llama/Llama-3.1-8B-Instruct",
            "meta-llama/Llama-3.1-70B-Instruct",
            "meta-llama/Llama-2-7b-hf",
        ],
        "mistral": [
            "mistralai/Mistral-7B-Instruct-v0.3",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
        ],
        "qwen2": [
            "Qwen/Qwen2.5-7B-Instruct",
            "Qwen/Qwen2.5-14B-Instruct",
            "Qwen/Qwen2.5-3B-Instruct",
            "Qwen/Qwen2.5-72B-Instruct",
        ],
        "gemma2": [
            "google/gemma-2-9b-it",
            "google/gemma-2-27b-it",
        ],
        "phi3": [
            "microsoft/Phi-3-mini-4k-instruct",
            "microsoft/Phi-3-medium-4k-instruct",
        ],
    }

    if architecture.lower() in known_bases:
        return [{"model_id": m} for m in known_bases[architecture.lower()]]

    return search_models(
        query=f"{architecture} instruct",
        limit=5,
        sort="downloads",
        token=token,
    )
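

if __name__ == "__main__":
    # Minimal smoke test; hits the live HF API, so it needs network access.
    # The architecture argument here is an arbitrary example.
    for suggestion in get_popular_base_models("qwen2"):
        print(suggestion.get("model_id", suggestion))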