| """ |
| SAM 2 Zero-Shot Segmentation Model |
| |
| This module implements zero-shot segmentation using SAM 2 with advanced |
| text prompting, visual grounding, and attention-based prompt generation. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Dict, List, Optional, Tuple, Union |
| import numpy as np |
| from PIL import Image |
| import clip |
| from segment_anything_2 import sam_model_registry, SamPredictor |
| from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel |
| import cv2 |
|
|
|
|
class SAM2ZeroShot(nn.Module):
    """
    SAM 2 Zero-Shot Segmentation Model

    Performs zero-shot segmentation using SAM 2 with advanced text prompting
    and visual grounding techniques.
    """

    def __init__(
        self,
        sam2_checkpoint: str,
        model_cfg: str = "sam2_hiera_l.yaml",
        clip_model_name: str = "ViT-B/32",
        device: str = "cuda",
        use_attention_maps: bool = True,
        use_grounding_dino: bool = False,
        temperature: float = 0.1
    ):
        super().__init__()
        self.device = device
        self.temperature = temperature
        self.use_attention_maps = use_attention_maps
        # Reserved for optional GroundingDINO box prompts; not used in this module.
        self.use_grounding_dino = use_grounding_dino

        # Build SAM 2 from its Hydra config and checkpoint, then wrap it in
        # the image predictor that handles prompt-based mask prediction.
        self.sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device)
        self.sam2_predictor = SAM2ImagePredictor(self.sam2)

        # OpenAI CLIP for scoring image-text similarity.
        self.clip_model, self.clip_preprocess = clip.load(clip_model_name, device=device)
        self.clip_model.eval()

        # Hugging Face CLIP (with projection heads) for building the
        # text-conditioned patch relevance maps in generate_attention_maps.
        if self.use_attention_maps:
            self.clip_hf_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
            self.clip_hf_model.to(device)
            self.clip_hf_model.eval()

        self.advanced_prompts = {
            "satellite": {
                "building": [
                    "satellite view of buildings", "aerial photograph of structures",
                    "overhead view of houses", "urban development from above",
                    "rooftop structures", "architectural features from space"
                ],
                "road": [
                    "satellite view of roads", "aerial photograph of streets",
                    "overhead view of highways", "transportation network from above",
                    "paved surfaces", "road infrastructure from space"
                ],
                "vegetation": [
                    "satellite view of vegetation", "aerial photograph of forests",
                    "overhead view of trees", "green areas from above",
                    "natural landscape", "plant life from space"
                ],
                "water": [
                    "satellite view of water", "aerial photograph of lakes",
                    "overhead view of rivers", "water bodies from above",
                    "aquatic features", "water resources from space"
                ]
            },
            "fashion": {
                "shirt": [
                    "fashion photography of shirts", "clothing item top",
                    "apparel garment", "upper body clothing",
                    "casual wear", "formal attire top"
                ],
                "pants": [
                    "fashion photography of pants", "lower body clothing",
                    "trousers garment", "leg wear",
                    "casual pants", "formal trousers"
                ],
                "dress": [
                    "fashion photography of dresses", "full body garment",
                    "formal dress", "evening wear",
                    "casual dress", "party dress"
                ],
                "shoes": [
                    "fashion photography of shoes", "footwear item",
                    "foot covering", "walking shoes",
                    "casual footwear", "formal shoes"
                ]
            },
            "robotics": {
                "robot": [
                    "robotics environment with robot", "automation equipment",
                    "mechanical arm", "industrial robot",
                    "automated system", "robotic device"
                ],
                "tool": [
                    "robotics environment with tools", "industrial equipment",
                    "mechanical tools", "work equipment",
                    "hand tools", "power tools"
                ],
                "safety": [
                    "robotics environment with safety equipment", "protective gear",
                    "safety helmet", "safety vest",
                    "protective clothing", "safety equipment"
                ]
            }
        }

        # Prompt-rewriting strategies applied on top of the base prompts.
        self.prompt_strategies = {
            "descriptive": lambda x: f"a clear image showing {x}",
            "contextual": lambda x: f"in a typical environment, {x}",
            "detailed": lambda x: f"high quality photograph of {x} with clear details",
            "contrastive": lambda x: f"{x} standing out from the background"
        }

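    # Helper added for this sketch: clip.load's preprocess pipeline expects a
    # PIL image, while the methods below receive tensors. This assumes float
    # tensors are [C, H, W] in [0, 1]; adjust if your data is normalized
    # differently.
    @staticmethod
    def _to_pil(image: Union[torch.Tensor, np.ndarray, Image.Image]) -> Image.Image:
        """Convert a tensor or array image to PIL for CLIP preprocessing."""
        if isinstance(image, torch.Tensor):
            array = (image.permute(1, 2, 0).cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
            return Image.fromarray(array)
        if isinstance(image, np.ndarray):
            return Image.fromarray(image.astype(np.uint8))
        return image
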
    def generate_attention_maps(
        self,
        image: torch.Tensor,
        text_prompts: List[str]
    ) -> torch.Tensor:
        """Generate text-conditioned relevance maps from CLIP.

        CLIP has no text-to-vision cross-attention to read off, so this
        approximates one (MaskCLIP-style): the vision transformer's patch
        tokens are projected into the joint embedding space and scored
        against each prompt embedding. Returns [num_prompts, grid, grid],
        or None when attention maps are disabled.
        """
        if not self.use_attention_maps:
            return None

        text_inputs = self.clip_tokenizer(
            text_prompts,
            padding=True,
            return_tensors="pt"
        ).to(self.device)

        image_inputs = self.clip_preprocess(self._to_pil(image)).unsqueeze(0).to(self.device)

        with torch.no_grad():
            text_features = self.clip_hf_model.get_text_features(**text_inputs)
            vision_outputs = self.clip_hf_model.vision_model(pixel_values=image_inputs)
            # Drop the CLS token, then project patch tokens into the joint space.
            patch_tokens = vision_outputs.last_hidden_state[:, 1:, :]
            patch_tokens = self.clip_hf_model.vision_model.post_layernorm(patch_tokens)
            patch_features = self.clip_hf_model.visual_projection(patch_tokens)[0]

        text_features = F.normalize(text_features, dim=-1)
        patch_features = F.normalize(patch_features, dim=-1)

        # [num_prompts, num_patches] similarities, reshaped onto the patch grid.
        relevance = text_features @ patch_features.T
        grid = int(relevance.shape[1] ** 0.5)
        attention_maps = relevance.reshape(-1, grid, grid)

        return attention_maps

    def extract_attention_points(
        self,
        attention_maps: torch.Tensor,
        image_size: Tuple[int, int],
        num_points: int = 5
    ) -> List[Tuple[int, int]]:
        """Extract peak-relevance points (in image coordinates) for SAM 2 prompting."""
        if attention_maps is None:
            return []

        # Upsample the patch-grid maps to full image resolution so the argmax
        # indices below are valid pixel coordinates.
        h, w = image_size
        attention_maps = F.interpolate(
            attention_maps.unsqueeze(0),
            size=(h, w),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)

        # One point per map: the pixel with the highest relevance score.
        points = []
        for i in range(min(num_points, attention_maps.shape[0])):
            attention_map = attention_maps[i]
            max_idx = int(torch.argmax(attention_map))
            y, x = max_idx // w, max_idx % w
            points.append((int(x), int(y)))

        return points

    def generate_enhanced_prompts(
        self,
        domain: str,
        class_names: List[str]
    ) -> List[str]:
        """Generate enhanced prompts using multiple strategies."""
        enhanced_prompts = []

        for class_name in class_names:
            if domain in self.advanced_prompts and class_name in self.advanced_prompts[domain]:
                base_prompts = self.advanced_prompts[domain][class_name]

                # Keep the curated prompts as-is...
                enhanced_prompts.extend(base_prompts)

                # ...and rewrite the top two with every strategy template.
                for strategy_func in self.prompt_strategies.values():
                    for base_prompt in base_prompts[:2]:
                        enhanced_prompts.append(strategy_func(base_prompt))
            else:
                # Unknown domain/class: fall back to the raw class name.
                enhanced_prompts.append(class_name)
                enhanced_prompts.append(f"object: {class_name}")

        return enhanced_prompts

    def compute_text_image_similarity(
        self,
        image: torch.Tensor,
        text_prompts: List[str]
    ) -> torch.Tensor:
        """Compute temperature-scaled similarity between the image and each prompt."""
        text_tokens = clip.tokenize(text_prompts).to(self.device)

        with torch.no_grad():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features = F.normalize(text_features, dim=-1)

            image_input = self.clip_preprocess(self._to_pil(image)).unsqueeze(0).to(self.device)
            image_features = self.clip_model.encode_image(image_input)
            image_features = F.normalize(image_features, dim=-1)

            # [1, num_prompts] cosine similarities, sharpened by the temperature.
            similarity = torch.matmul(image_features, text_features.T) / self.temperature

        return similarity

    def generate_sam2_prompts(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str]
    ) -> List[Dict]:
        """Generate point and box prompts for SAM 2 zero-shot segmentation."""
        prompts = []

        # Build the text prompt pool and score it against the image.
        text_prompts = self.generate_enhanced_prompts(domain, class_names)
        similarities = self.compute_text_image_similarity(image, text_prompts)

        # Relevance maps supply candidate foreground points in image coordinates.
        h, w = int(image.shape[-2]), int(image.shape[-1])
        attention_maps = self.generate_attention_maps(image, text_prompts)
        attention_points = self.extract_attention_points(attention_maps, (h, w))

        for class_name in class_names:
            # Collect the text prompts that mention this class.
            class_text_indices = [
                j for j, prompt in enumerate(text_prompts)
                if class_name.lower() in prompt.lower()
            ]

            if not class_text_indices:
                continue

            # Use the best-matching prompt's similarity as the class confidence.
            class_similarities = similarities[0, class_text_indices]
            best_idx = torch.argmax(class_similarities)
            best_similarity = class_similarities[best_idx]

            if best_similarity > 0.2:
                # Point prompts at the relevance-map peaks.
                if attention_points:
                    for point in attention_points[:3]:
                        prompts.append({
                            'type': 'point',
                            'data': point,
                            'label': 1,
                            'class': class_name,
                            'confidence': best_similarity.item(),
                            'source': 'attention'
                        })

                # Always add the image center as a fallback point.
                center_point = [w // 2, h // 2]
                prompts.append({
                    'type': 'point',
                    'data': center_point,
                    'label': 1,
                    'class': class_name,
                    'confidence': best_similarity.item(),
                    'source': 'center'
                })

                # Strong matches also get a coarse central box prompt.
                if best_similarity > 0.4:
                    box = [w // 4, h // 4, 3 * w // 4, 3 * h // 4]
                    prompts.append({
                        'type': 'box',
                        'data': box,
                        'class': class_name,
                        'confidence': best_similarity.item(),
                        'source': 'similarity'
                    })

        return prompts

    def segment(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str]
    ) -> Dict[str, torch.Tensor]:
        """
        Perform zero-shot segmentation.

        Args:
            image: Input image tensor [C, H, W]
            domain: Domain name (satellite, fashion, robotics)
            class_names: List of class names to segment

        Returns:
            Dictionary mapping each detected class name to a binary mask
        """
        # The SAM 2 predictor expects an HWC uint8 RGB array; float tensors
        # are assumed to be in [0, 1].
        if isinstance(image, torch.Tensor):
            image_np = image.permute(1, 2, 0).cpu().numpy()
            if image_np.dtype != np.uint8:
                image_np = (image_np * 255.0).clip(0, 255).astype(np.uint8)
        else:
            image_np = image

        self.sam2_predictor.set_image(image_np)

        prompts = self.generate_sam2_prompts(image, domain, class_names)

        results = {}

        for prompt in prompts:
            class_name = prompt['class']

            if prompt['type'] == 'point':
                masks, scores, logits = self.sam2_predictor.predict(
                    point_coords=np.array([prompt['data']]),
                    point_labels=np.array([prompt['label']]),
                    multimask_output=True
                )
                confidence_threshold = 0.2
            elif prompt['type'] == 'box':
                masks, scores, logits = self.sam2_predictor.predict(
                    box=np.array(prompt['data']),
                    multimask_output=True
                )
                confidence_threshold = 0.3
            else:
                continue

            # Keep the highest-scoring of SAM 2's candidate masks.
            best_mask_idx = np.argmax(scores)
            mask = torch.from_numpy(masks[best_mask_idx]).float()

            # Union masks produced by different prompts for the same class.
            if prompt['confidence'] > confidence_threshold:
                if class_name not in results:
                    results[class_name] = mask
                else:
                    results[class_name] = torch.max(results[class_name], mask)

        return results

    def forward(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str]
    ) -> Dict[str, torch.Tensor]:
        """Forward pass."""
        return self.segment(image, domain, class_names)


class ZeroShotEvaluator:
    """Evaluator for zero-shot segmentation."""

    def __init__(self):
        self.metrics = {}

    def compute_iou(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float:
        """Compute Intersection over Union between two boolean masks."""
        intersection = (pred_mask & gt_mask).sum()
        union = (pred_mask | gt_mask).sum()
        return (intersection / union).item() if union > 0 else 0.0

    def compute_dice(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor) -> float:
        """Compute the Dice coefficient between two boolean masks."""
        intersection = (pred_mask & gt_mask).sum()
        total = pred_mask.sum() + gt_mask.sum()
        return (2 * intersection / total).item() if total > 0 else 0.0

    def evaluate(
        self,
        predictions: Dict[str, torch.Tensor],
        ground_truth: Dict[str, torch.Tensor]
    ) -> Dict[str, float]:
        """Evaluate zero-shot segmentation results against ground-truth masks."""
        results = {}

        for class_name in ground_truth.keys():
            if class_name in predictions:
                pred_mask = predictions[class_name] > 0.5
                gt_mask = ground_truth[class_name] > 0.5

                results[f"{class_name}_iou"] = self.compute_iou(pred_mask, gt_mask)
                results[f"{class_name}_dice"] = self.compute_dice(pred_mask, gt_mask)

        # Aggregate with a suffix match so class names that happen to contain
        # "iou" or "dice" cannot leak into the wrong mean.
        if results:
            results['mean_iou'] = np.mean([v for k, v in results.items() if k.endswith('_iou')])
            results['mean_dice'] = np.mean([v for k, v in results.items() if k.endswith('_dice')])

        return results
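

# ---------------------------------------------------------------------------
# Minimal usage sketch, illustrative only: the checkpoint, config, and image
# paths below are placeholders for your local setup, and the ground-truth
# masks in the commented evaluation step are assumed to exist.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load an RGB image as a [C, H, W] float tensor in [0, 1].
    pil_image = Image.open("example.jpg").convert("RGB")  # placeholder path
    image = torch.from_numpy(np.array(pil_image)).permute(2, 0, 1).float() / 255.0

    model = SAM2ZeroShot(
        sam2_checkpoint="checkpoints/sam2_hiera_large.pt",  # placeholder path
        model_cfg="sam2_hiera_l.yaml",
        device=device,
    )

    # Segment two satellite classes and report simple mask statistics.
    masks = model.segment(image, domain="satellite", class_names=["building", "road"])
    for name, mask in masks.items():
        print(f"{name}: shape={tuple(mask.shape)}, area={int(mask.sum().item())} px")

    # With ground-truth masks available, ZeroShotEvaluator reports per-class
    # IoU/Dice plus their means:
    # evaluator = ZeroShotEvaluator()
    # print(evaluator.evaluate(masks, {"building": gt_building, "road": gt_road}))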