# (Scraped from a Hugging Face Space page — original status header: "Spaces: Running".)
# Third-party dependencies.
import chromadb
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPModel, CLIPProcessor

# Project-local configuration and logging helpers.
from utils.config import config, get_logger

logger = get_logger("Models")

# Prefer the GPU whenever torch can see one; otherwise run everything on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Initializing models on: {device.upper()}")
# 1. CLIP: processor + model pair, moved onto the selected device.
_clip_id = config.clip_model_id
logger.info(f"Loading CLIP ({config.clip_model_id})...")
clip_processor = CLIPProcessor.from_pretrained(_clip_id)
clip_model = CLIPModel.from_pretrained(_clip_id).to(device)
# 2. Vector store: an in-memory ChromaDB client for this process.
logger.info("Initializing ChromaDB...")
chroma_client = chromadb.Client()

# Start from a clean slate: drop any leftover collection with the same name,
# then recreate it.  The delete is deliberately best-effort — it raises when
# the collection does not exist yet, and that case is fine to ignore.
try:
    chroma_client.delete_collection(config.collection_name)
except Exception:
    pass
collection = chroma_client.create_collection(name=config.collection_name)
# --- THE MONKEY PATCH ---
# Transformers v4.45+ probes `module.all_tied_weights_keys` on custom
# (trust_remote_code) models that predate the attribute, which raises
# AttributeError while loading.  Patch Module.__getattr__ so that probe
# succeeds with an empty mapping instead of failing.
#
# Fix over the naive version: delegate to the stock __getattr__ FIRST, so a
# module that legitimately resolves `all_tied_weights_keys` (e.g. via its
# parameters/buffers/submodules lookup) is not shadowed by the patch; the
# empty-dict fallback applies only when the original lookup fails.
_orig_getattr = torch.nn.Module.__getattr__

def _patched_getattr(self, name):
    """Stock Module.__getattr__, but the `all_tied_weights_keys` probe
    falls back to an empty dict instead of raising AttributeError.

    Args:
        self: the torch.nn.Module being inspected.
        name: attribute name that normal instance lookup failed to find.

    Returns:
        Whatever the original __getattr__ resolves; `{}` only for a
        missing `all_tied_weights_keys`.

    Raises:
        AttributeError: for any other name the original lookup rejects.
    """
    try:
        # Genuine parameters/buffers/submodules (and any real
        # all_tied_weights_keys the module resolves) win first.
        return _orig_getattr(self, name)
    except AttributeError:
        if name == "all_tied_weights_keys":
            # Probe fallback: report "no tied weights" rather than fail.
            return {}
        raise

torch.nn.Module.__getattr__ = _patched_getattr
# ------------------------
# 3. Vision-language model + its tokenizer.
# NOTE(security): trust_remote_code=True executes Python shipped in the model
# repo; the `revision` pin (config.vlm_revision) keeps that code fixed.
logger.info(f"Loading VLM ({config.vlm_model_id})...")
_vlm_id = config.vlm_model_id
vlm_model = AutoModelForCausalLM.from_pretrained(
    _vlm_id,
    trust_remote_code=True,
    revision=config.vlm_revision,
).to(device)
vlm_tokenizer = AutoTokenizer.from_pretrained(_vlm_id, revision=config.vlm_revision)