"""Initialize the shared ML models and the vector store.

Import-time side effects: selects the compute device, downloads/loads CLIP
and a custom VLM onto that device, and (re)creates the ChromaDB collection
named in `config`. Import this module once, early.
"""

import torch
import chromadb
from transformers import CLIPProcessor, CLIPModel, AutoModelForCausalLM, AutoTokenizer

from utils.config import config, get_logger

logger = get_logger("Models")

# Prefer GPU when available; every model below is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Initializing models on: {device.upper()}")

# 1. Load CLIP (image/text embedding model).
logger.info(f"Loading CLIP ({config.clip_model_id})...")
clip_processor = CLIPProcessor.from_pretrained(config.clip_model_id)
clip_model = CLIPModel.from_pretrained(config.clip_model_id).to(device)

# 2. Initialize ChromaDB (in-memory client) and start from a clean collection.
logger.info("Initializing ChromaDB...")
chroma_client = chromadb.Client()
try:
    # Best-effort wipe: the collection may simply not exist on a fresh start,
    # so a failure here is expected and non-fatal — but leave a trace for
    # debugging instead of swallowing it silently.
    chroma_client.delete_collection(config.collection_name)
except Exception as exc:
    logger.debug(f"Could not delete collection '{config.collection_name}': {exc}")
collection = chroma_client.create_collection(name=config.collection_name)

# --- THE MONKEY PATCH ---
# HACK: Transformers v4.45+ looks up `all_tied_weights_keys` on models; older
# custom (trust_remote_code) model classes don't define it, so the lookup
# raises AttributeError inside from_pretrained. We intercept only that
# *failed* lookup and return an empty mapping.
#
# Note: `torch.nn.Module.__getattr__` is also the mechanism that resolves
# parameters/buffers/submodules (via _parameters/_buffers/_modules), so we
# must delegate to the original FIRST and fall back to {} only when it
# raises — a blind short-circuit on the name would shadow any genuine
# attribute that happens to share it.
_orig_getattr = torch.nn.Module.__getattr__


def _patched_getattr(self, name):
    """Delegate to the stock nn.Module lookup; shim only the missing key."""
    try:
        return _orig_getattr(self, name)
    except AttributeError:
        if name == "all_tied_weights_keys":
            return {}
        raise


torch.nn.Module.__getattr__ = _patched_getattr
# ------------------------

# 3. Load the VLM. trust_remote_code executes model code from the hub, so the
# revision is pinned to a known-good commit for reproducibility and safety.
logger.info(f"Loading VLM ({config.vlm_model_id})...")
vlm_model = AutoModelForCausalLM.from_pretrained(
    config.vlm_model_id,
    trust_remote_code=True,
    revision=config.vlm_revision,
).to(device)
vlm_tokenizer = AutoTokenizer.from_pretrained(config.vlm_model_id, revision=config.vlm_revision)