# forgekit/model_info.py
"""HuggingFace Hub API wrapper for model discovery and info retrieval."""
import time
from dataclasses import dataclass, field
from typing import Any, Optional
from urllib.parse import urlencode
import requests
HF_API = "https://huggingface.co/api"
_session = requests.Session()
_session.headers.update({"Accept": "application/json"})
# Simple in-memory cache with TTL
_cache: dict[str, tuple[float, Any]] = {}
CACHE_TTL = 300 # 5 minutes
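# NOTE: entries are keyed by URL alone, so a response fetched with a token is
# reused for later calls that omit the token (and vice versa).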
def _cached_get(url: str, token: Optional[str] = None, ttl: int = CACHE_TTL) -> dict:
"""GET with caching and rate-limit handling."""
now = time.time()
if url in _cache and (now - _cache[url][0]) < ttl:
return _cache[url][1]
headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"
resp = _session.get(url, headers=headers, timeout=15)
    if resp.status_code == 429:
        # Retry-After may be seconds or an HTTP date; fall back to 5s if not numeric
        retry_after = resp.headers.get("Retry-After", "5")
        time.sleep(int(retry_after) if retry_after.isdigit() else 5)
        resp = _session.get(url, headers=headers, timeout=15)
resp.raise_for_status()
data = resp.json()
_cache[url] = (now, data)
return data
@dataclass
class ModelInfo:
"""Parsed model information from HF Hub."""
model_id: str
model_type: str = "unknown"
architectures: list[str] = field(default_factory=list)
vocab_size: int = 0
hidden_size: int = 0
intermediate_size: int = 0
num_hidden_layers: int = 0
num_attention_heads: int = 0
num_key_value_heads: int = 0
max_position_embeddings: int = 0
torch_dtype: str = "unknown"
pipeline_tag: str = ""
tags: list[str] = field(default_factory=list)
downloads: int = 0
likes: int = 0
size_bytes: int = 0
gated: bool = False
private: bool = False
trust_remote_code: bool = False
error: Optional[str] = None
@property
def param_estimate(self) -> str:
"""Rough parameter count estimate based on architecture."""
if self.size_bytes > 0:
# Rough: model files in bf16 ≈ 2 bytes per param
params = self.size_bytes / 2
if params > 1e9:
return f"{params/1e9:.1f}B"
elif params > 1e6:
return f"{params/1e6:.0f}M"
return "unknown"
@property
def arch_signature(self) -> str:
"""Unique signature for architecture matching."""
return f"{self.model_type}|{self.hidden_size}|{self.intermediate_size}"
@property
def display_name(self) -> str:
"""Short display name (without org prefix)."""
return self.model_id.split("/")[-1] if "/" in self.model_id else self.model_id
@property
def ram_estimate_gb(self) -> float:
"""Estimated RAM needed for merging (roughly 2.5x model size for bf16 merge)."""
if self.size_bytes > 0:
return round(self.size_bytes * 2.5 / (1024**3), 1)
return 0.0
def to_dict(self) -> dict:
return {
"model_id": self.model_id,
"model_type": self.model_type,
"architectures": self.architectures,
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"intermediate_size": self.intermediate_size,
"num_hidden_layers": self.num_hidden_layers,
"num_attention_heads": self.num_attention_heads,
"torch_dtype": self.torch_dtype,
"pipeline_tag": self.pipeline_tag,
"downloads": self.downloads,
"likes": self.likes,
"param_estimate": self.param_estimate,
"ram_estimate_gb": self.ram_estimate_gb,
"gated": self.gated,
"private": self.private,
}
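# Illustrative use of the derived properties (hypothetical numbers):
#   ModelInfo(model_id="org/model", size_bytes=14_000_000_000).param_estimate
#   -> "7.0B" (bf16 assumption); .ram_estimate_gb -> 32.6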
def fetch_model_info(model_id: str, token: Optional[str] = None) -> ModelInfo:
"""Fetch comprehensive model information from HF Hub.
Args:
model_id: Full model ID (e.g., "Qwen/Qwen2.5-Coder-7B-Instruct")
token: Optional HF API token for gated/private models
Returns:
ModelInfo dataclass with all available information
"""
info = ModelInfo(model_id=model_id)
# Fetch main model info
try:
        # blobs=true asks the API to include per-file sizes in `siblings`
        data = _cached_get(f"{HF_API}/models/{model_id}?blobs=true", token=token)
except requests.exceptions.HTTPError as e:
        if e.response.status_code in (401, 403):
            info.error = "Gated or private model — HF token required"
            info.gated = True
elif e.response.status_code == 404:
info.error = f"Model not found: {model_id}"
else:
info.error = f"API error: {e.response.status_code}"
return info
except Exception as e:
info.error = f"Connection error: {str(e)}"
return info
# Parse basic metadata
info.pipeline_tag = data.get("pipeline_tag", "")
info.tags = data.get("tags", [])
info.downloads = data.get("downloads", 0)
info.likes = data.get("likes", 0)
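    # HF reports `gated` as False or a string mode ("auto"/"manual")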
info.gated = data.get("gated", False) not in (False, None)
info.private = data.get("private", False)
# Parse config (architecture details)
config = data.get("config", {})
if config:
info.model_type = config.get("model_type", "unknown")
info.architectures = config.get("architectures", [])
# Fetch full config.json for detailed architecture info
# (the API endpoint only returns basic config fields)
try:
full_config = _cached_get(
f"https://huggingface.co/{model_id}/resolve/main/config.json",
token=token,
)
info.model_type = full_config.get("model_type", info.model_type)
info.architectures = full_config.get("architectures", info.architectures)
info.vocab_size = full_config.get("vocab_size", 0)
info.hidden_size = full_config.get("hidden_size", 0)
info.intermediate_size = full_config.get("intermediate_size", 0)
info.num_hidden_layers = full_config.get("num_hidden_layers", 0)
info.num_attention_heads = full_config.get("num_attention_heads", 0)
info.num_key_value_heads = full_config.get("num_key_value_heads", 0)
info.max_position_embeddings = full_config.get("max_position_embeddings", 0)
info.torch_dtype = full_config.get("torch_dtype", "unknown")
if "auto_map" in full_config:
info.trust_remote_code = True
except Exception:
# Fall back to basic config from API
if config:
info.vocab_size = config.get("vocab_size", 0)
info.hidden_size = config.get("hidden_size", 0)
else:
info.error = "Could not fetch config.json — model may need trust_remote_code=True"
info.trust_remote_code = True
# Estimate total model size from siblings (files)
siblings = data.get("siblings", [])
total_size = 0
for f in siblings:
fname = f.get("rfilename", "")
size = f.get("size", 0) or 0
# Count only model weight files
if any(fname.endswith(ext) for ext in
[".safetensors", ".bin", ".pt", ".pth", ".gguf"]):
total_size += size
info.size_bytes = total_size
return info
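# Example (sketch; requires network access, token only for gated repos):
#   info = fetch_model_info("Qwen/Qwen2.5-7B-Instruct")
#   if not info.error:
#       print(info.arch_signature, info.param_estimate, info.ram_estimate_gb)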
def search_models(
query: str = "",
author: str = "",
architecture: str = "",
limit: int = 20,
sort: str = "downloads",
token: Optional[str] = None,
) -> list[dict]:
"""Search HuggingFace Hub for models.
Args:
query: Search query string
author: Filter by author/organization
architecture: Filter by model_type (e.g., "llama", "qwen2")
limit: Max results to return
sort: Sort by "downloads", "likes", "created", "modified"
token: Optional HF API token
Returns:
List of dicts with basic model info
"""
params = {
"limit": min(limit, 100),
"sort": sort,
"direction": -1,
"config": True,
}
if query:
params["search"] = query
if author:
params["author"] = author
    # urlencode handles spaces and special characters in the search query
    url = f"{HF_API}/models?{urlencode(params)}"
    try:
        data = _cached_get(url, token=token, ttl=60)  # shorter cache for search
except Exception as e:
return [{"error": str(e)}]
results = []
for m in data:
config = m.get("config", {}) or {}
model_type = config.get("model_type", "")
        # Client-side filter by architecture; may yield fewer than `limit` hits
        if architecture and model_type.lower() != architecture.lower():
            continue
results.append({
"model_id": m.get("modelId", ""),
"model_type": model_type,
"pipeline_tag": m.get("pipeline_tag", ""),
"downloads": m.get("downloads", 0),
"likes": m.get("likes", 0),
"tags": m.get("tags", [])[:5],
})
return results[:limit]
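# Example (sketch):
#   for hit in search_models(query="instruct", architecture="qwen2", limit=5):
#       print(hit.get("model_id"), hit.get("downloads"))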
def get_popular_base_models(architecture: str = "", token: Optional[str] = None) -> list[dict]:
"""Get popular base models for a given architecture type.
Useful for suggesting base_model in merge configs.
"""
# Common base models by architecture
known_bases = {
"llama": [
"meta-llama/Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.1-70B-Instruct",
"meta-llama/Llama-2-7b-hf",
],
"mistral": [
"mistralai/Mistral-7B-Instruct-v0.3",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
],
"qwen2": [
"Qwen/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-14B-Instruct",
"Qwen/Qwen2.5-3B-Instruct",
"Qwen/Qwen2.5-72B-Instruct",
],
"gemma2": [
"google/gemma-2-9b-it",
"google/gemma-2-27b-it",
],
"phi3": [
"microsoft/Phi-3-mini-4k-instruct",
"microsoft/Phi-3-medium-4k-instruct",
],
}
if architecture.lower() in known_bases:
return [{"model_id": m} for m in known_bases[architecture.lower()]]
# Fallback: search for popular instruct models
return search_models(
query=f"{architecture} instruct",
limit=5,
sort="downloads",
token=token,
)
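if __name__ == "__main__":
    # Minimal smoke test, assuming network access; the model ID is just an
    # example pulled from the known-bases table above.
    demo = fetch_model_info("mistralai/Mistral-7B-Instruct-v0.3")
    if demo.error:
        print(f"error: {demo.error}")
    else:
        for key, value in demo.to_dict().items():
            print(f"{key}: {value}")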