# NOTE(review): the three lines below ("Spaces:" / "Running" / "Running") were
# Hugging Face Spaces status-banner residue from a page scrape, not source code;
# preserved here as a comment so the module parses.
| import os | |
| import shutil | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| import zarr | |
| from PIL import Image | |
| from typing import Tuple, List | |
| from utils.config import config, get_logger | |
| from utils.models import device, clip_processor, clip_model, collection, chroma_client, vlm_model, vlm_tokenizer | |
| logger = get_logger("Engine") | |
| def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]: | |
| if not video_path: | |
| return "Please upload a video.", [] | |
| if os.path.exists(config.cache_dir): | |
| logger.info(f"Clearing old cache at {config.cache_dir}...") | |
| shutil.rmtree(config.cache_dir, ignore_errors=True) | |
| logger.info("Starting fast extraction process...") | |
| vidcap = cv2.VideoCapture(video_path) | |
| video_fps = vidcap.get(cv2.CAP_PROP_FPS) | |
| frame_interval = max(1, int(video_fps / config.default_fps)) | |
| success, first_frame = vidcap.read() | |
| if not success: | |
| return "Failed to read video.", [] | |
| rgb_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB) | |
| h, w, c = rgb_first.shape | |
| logger.info(f"Allocating strict Zarr v3 SSD cache at {config.cache_dir}...") | |
| frame_cache = zarr.create_array( | |
| config.cache_dir, shape=(0, h, w, c), chunks=(10, h, w, c), dtype='uint8', zarr_format=3 | |
| ) | |
| timestamps, count, frame_idx = [], 0, 0 | |
| while success: | |
| if count % frame_interval == 0: | |
| rgb_image = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB) | |
| frame_cache.append(np.expand_dims(rgb_image, axis=0), axis=0) | |
| timestamps.append(count / video_fps) | |
| frame_idx += 1 | |
| success, first_frame = vidcap.read() | |
| count += 1 | |
| vidcap.release() | |
| logger.info("Generating CLIP embeddings in batches...") | |
| all_embeddings = [] | |
| total_frames = frame_cache.shape[0] | |
| for i in range(0, total_frames, config.batch_size): | |
| batch_arrays = frame_cache[i : i + config.batch_size] | |
| batch_pil = [Image.fromarray(arr) for arr in batch_arrays] | |
| inputs = clip_processor(images=batch_pil, return_tensors="pt").to(device) | |
| with torch.no_grad(): | |
| # 🚨 BUGFIX: Manually extract and project the vision features | |
| vision_outputs = clip_model.vision_model(**inputs) | |
| features = clip_model.visual_projection(vision_outputs.pooler_output) | |
| normalized = (features / features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist() | |
| all_embeddings.extend(normalized) | |
| logger.info("Indexing into ChromaDB...") | |
| ids = [f"frame_{i}" for i in range(total_frames)] | |
| metadatas = [{"timestamp": ts, "frame_idx": i} for i, ts in enumerate(timestamps)] | |
| global collection | |
| chroma_client.delete_collection(config.collection_name) | |
| collection = chroma_client.create_collection(config.collection_name) | |
| collection.add(embeddings=all_embeddings, metadatas=metadatas, ids=ids) | |
| sample_frames = [Image.fromarray(frame_cache[i]) for i in range(min(3, total_frames))] | |
| return f"Processed {total_frames} frames strictly on SSD cache.", sample_frames | |
| def ask_video_question(query: str) -> Tuple[str, List[Image.Image]]: | |
| if collection.count() == 0: | |
| return "Please process a video first.", [] | |
| logger.info(f"Processing query: '{query}'") | |
| inputs = clip_processor(text=[query], return_tensors="pt", padding=True).to(device) | |
| with torch.no_grad(): | |
| # 🚨 BUGFIX: Manually extract and project the text features | |
| text_outputs = clip_model.text_model(**inputs) | |
| text_features = clip_model.text_projection(text_outputs.pooler_output) | |
| text_embedding = (text_features / text_features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist() | |
| results = collection.query(query_embeddings=text_embedding, n_results=3) | |
| frame_cache = zarr.open_array(config.cache_dir, mode="r") | |
| retrieved_images = [] | |
| for metadata in results['metadatas'][0]: | |
| img_array = frame_cache[int(metadata['frame_idx'])] | |
| retrieved_images.append(Image.fromarray(img_array)) | |
| logger.info("Generating VLM answer...") | |
| encoded_image = vlm_model.encode_image(retrieved_images[0]) | |
| answer = vlm_model.answer_question(encoded_image, query, vlm_tokenizer) | |
| return answer, retrieved_images |