Athrael committed
Commit b34d493 · verified · Parent(s): a4540ba

Upload 3 files

Files changed (3):
1. README.md +203 -11
2. app.py +543 -0
3. requirements.txt +21 -0
README.md CHANGED
@@ -1,13 +1,205 @@
  ---
- title: Eomt
- emoji: 😻
- colorFrom: red
- colorTo: yellow
- sdk: gradio
- sdk_version: 5.35.0
- app_file: app.py
- pinned: false
- license: mit
- ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # EOMT Panoptic Segmentation App
+
+ A Gradio-based web application for interactive panoptic segmentation using the **EOMT (Encoder-only Mask Transformer)** model, a minimalist approach that repurposes a plain Vision Transformer (ViT) for image segmentation.
+
+ ## 🚀 Features
+
+ - **Interactive Web Interface**: User-friendly Gradio interface for uploading and processing images
+ - **Multiple Visualization Types**:
+   - Segmentation mask with color-coded segments (with error handling for empty masks)
+   - Overlay visualization on the original image with transparency control
+   - Contour detection with distinct color coding for each segment
+   - Individual instance masks in a grid layout with segment statistics
+   - Edge detection with improved boundary highlighting
+   - Segment isolation showing the largest segment with detailed information
+   - Boundary density heatmap with gradient magnitude visualization
+ - **Real-time Processing**: Fast inference with CUDA support and proper GPU memory handling
+ - **Sample Images**: Built-in sample images with clickable thumbnail selection
+ - **Detailed Analytics**: Comprehensive segment statistics and visual analysis
+ - **Enhanced Visualizations**: Improved color coding, titles, and statistical information for better analysis
+
+ ## 🧠 Model
+
+ This application uses the **EOMT (Encoder-only Mask Transformer)** model, as presented in the CVPR 2025 highlight paper ["Your ViT is Secretly an Image Segmentation Model"](https://www.tue-mps.org/eomt/):
+
+ - **Model ID**: `tue-mps/coco_panoptic_eomt_large_640`
+ - **Task**: Panoptic segmentation
+ - **Dataset**: COCO Panoptic
+ - **Input Size**: 640×640 (images are resized automatically)
+ - **Architecture**: Plain Vision Transformer (ViT) repurposed for segmentation
+ - **Key Innovation**: No adapters, no decoders; just the ViT encoding image patches and segmentation queries as tokens
+ - **Performance**: Up to 4× faster than more complex methods while maintaining state-of-the-art accuracy
+
+ ### Research Citation
+ ```
+ @inproceedings{kerssies2025eomt,
+   author    = {Kerssies, Tommie and Cavagnero, Niccol\`{o} and Hermans, Alexander and Norouzi, Narges and Averta, Giuseppe and Leibe, Bastian and Dubbelman, Gijs and de Geus, Daan},
+   title     = {Your ViT is Secretly an Image Segmentation Model},
+   booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+   year      = {2025},
+ }
+ ```
+
+ 🔗 **Research Website**: [https://www.tue-mps.org/eomt/](https://www.tue-mps.org/eomt/)
+
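+ As a minimal sketch of how the app drives the model (this mirrors the inference code in `app.py`; the sample URL is one of the app's built-in COCO images):
+
+ ```python
+ import requests
+ import torch
+ from PIL import Image
+ from transformers import AutoImageProcessor, EomtForUniversalSegmentation
+
+ model_id = "tue-mps/coco_panoptic_eomt_large_640"
+ processor = AutoImageProcessor.from_pretrained(model_id)
+ model = EomtForUniversalSegmentation.from_pretrained(model_id)
+
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ image = Image.open(requests.get(url, stream=True).raw)
+
+ inputs = processor(images=image, return_tensors="pt")
+ with torch.inference_mode():
+     outputs = model(**inputs)
+
+ # Post-process back to the original resolution; the result is a dict with a
+ # "segmentation" map (HxW segment IDs) and "segments_info" metadata
+ pred = processor.post_process_panoptic_segmentation(
+     outputs, target_sizes=[(image.height, image.width)]
+ )[0]
+ print(pred["segmentation"].shape)
+ ```
+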
+ ## 🛠️ Installation
+
+ 1. **Clone the repository** (if not already done):
+    ```bash
+    git clone https://github.com/athrael.soju/little-scripts.git
+    cd little-scripts/eomt_panoptic_seg
+    ```
+
+ 2. **Install dependencies**:
+
+    **For standard setup:**
+    ```bash
+    uv pip install -r requirements.txt
+    ```
+
+    **For newer NVIDIA GPUs (e.g., RTX 5090):**
+    ```bash
+    # Install PyTorch nightly for latest CUDA support
+    uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
+
+    # Then install remaining dependencies
+    uv pip install -r requirements.txt
+    ```
+
+ 3. **Run the application**:
+    ```bash
+    python app.py
+    ```
+
+ ## 📋 Requirements
+
+ - Python 3.10+
+ - CUDA-compatible GPU (recommended for faster inference)
+ - Minimum 4 GB RAM
+ - Internet connection (for first-time model download)
+ - `uv` package manager (optional; plain `pip` works as well)
+
+ ### Dependencies
+ - **PyTorch**: Latest stable version (or nightly builds for newer GPUs such as the RTX 5090)
+ - **Transformers**: Latest development version from GitHub
+ - **Gradio**: For the web interface
+ - **OpenCV**: For image processing and contour detection
+ - **Matplotlib**: For visualization
+ - **NumPy & Pillow (PIL)**: For image handling
+ - **SciPy**: For advanced image processing operations
+
+ ## 🎯 Usage
+
+ 1. **Start the application**:
+    ```bash
+    python app.py
+    ```
+
+ 2. **Open your browser** and navigate to the provided URL (usually `http://localhost:7860`)
+
+ 3. **Upload an image** or select one of the sample images using the clickable thumbnails
+
+ 4. **Choose a visualization type**:
+    - **Mask**: Color-coded segmentation mask with error handling for empty results
+    - **Overlay**: Transparent mask overlay on the original image with optimized transparency
+    - **Contours**: Segment boundaries with distinct colors for each segment
+    - **Instance Masks**: Individual instance masks in a 3×3 grid showing segment IDs and pixel counts
+    - **Edge Detection**: Improved boundary highlighting using an RGBA overlay technique
+    - **Segment Isolation**: Largest segment isolated with detailed statistics
+    - **Heatmap**: Boundary density visualization with gradient magnitude and proper labeling
+
+ 5. **View results** with interactive segment analysis and high-quality visualizations
+
+ ## 🔧 Configuration
+
+ The application automatically downloads the EOMT model on first run. Model files are cached locally by Hugging Face Transformers (see the note below for relocating the cache).
+
+ ### Model Configuration
+ - **Model**: `tue-mps/coco_panoptic_eomt_large_640`
+ - **Image Processor**: Auto-configured for the model
+ - **Inference Mode**: PyTorch inference mode for optimal performance
+ - **CUDA Support**: Automatic GPU detection and proper tensor handling
+
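+ If the cache needs to live somewhere else (e.g., a larger disk), the standard Hugging Face `HF_HOME` environment variable can be set before launching; this is generic Hugging Face caching behavior, not something specific to this app, and the path below is only a placeholder:
+
+ ```bash
+ # Optional: relocate the Hugging Face model cache (placeholder path)
+ export HF_HOME=/path/to/large/disk/hf-cache
+ python app.py
+ ```
+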
+ ## 📊 Features Detail
+
+ ### Visualization Types
+
+ 1. **Mask View**: Clean segmentation mask with color-coded segments using matplotlib's tab20 colormap; includes error handling for empty masks
+ 2. **Overlay View**: Weighted combination of the original image and the segmentation mask with optimized transparency
+ 3. **Contours View**: Precise segment boundaries with distinct colors for each segment, using OpenCV and matplotlib colormaps
+ 4. **Instance Masks View**: Grid layout showing individual segments with segment IDs and pixel counts for detailed inspection
+ 5. **Edge Detection View**: Enhanced boundary detection using an RGBA overlay technique with Canny edge detection
+ 6. **Segment Isolation View**: Focus on the largest segment with detailed statistics (segment ID and pixel count)
+ 7. **Heatmap View**: Boundary density analysis with gradient magnitude visualization and proper colorbar labeling (see the sketch below)
+
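+ For reference, the boundary heatmap (view 7) is derived from the gradient of the segment-ID map, mirroring the computation in `app.py`; the small `mask` below is only an illustrative stand-in for a real prediction:
+
+ ```python
+ import numpy as np
+
+ # Illustrative stand-in for the post-processed segment-ID map (HxW integers)
+ mask = np.zeros((4, 6), dtype=int)
+ mask[:, 3:] = 1  # two segments side by side
+
+ gy, gx = np.gradient(mask.astype(float))
+ boundary_density = np.sqrt(gx**2 + gy**2)  # peaks where segment IDs change
+ print(boundary_density)
+ ```
+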
+ ### Interactive Features
+
+ - **Sample Images**: Pre-loaded COCO dataset images with large clickable thumbnails
+ - **Upload Support**: Drag-and-drop or click to upload custom images
+ - **Real-time Processing**: Fast inference with proper GPU memory management
+ - **High-Quality Output**: All visualizations rendered at 150 DPI for crisp results
+ - **Enhanced Information**: Detailed segment statistics, pixel counts, and gradient analysis
+
+ ## 🎨 Visualization Examples
+
+ The application provides multiple ways to visualize panoptic segmentation results:
+
+ - **Color-coded segments** using matplotlib's tab20 colormap for clear differentiation
+ - **Statistical analysis** showing segment distribution and coverage metrics
+ - **Contour detection** with OpenCV for precise boundary identification
+ - **Overlay techniques** with adjustable transparency for better visual understanding
+ - **Heatmap analysis** for boundary density and gradient visualization
+ - **Individual segment isolation** for detailed examination
+
+ ## 🚨 Troubleshooting
+
+ ### Common Issues
+
+ 1. **CUDA Tensor Errors**:
+    - The app automatically converts CUDA tensors to CPU before NumPy conversion (see the snippet after this list)
+    - Ensure a proper PyTorch installation with CUDA support
+    - Check GPU memory availability
+
+ 2. **Model Download Issues**:
+    - Ensure a stable internet connection
+    - Check Hugging Face Hub access
+    - Verify sufficient disk space
+
+ 3. **Memory Issues**:
+    - Reduce image size before upload
+    - Close other applications
+    - Consider using CPU inference for large images
+
+ 4. **Performance Issues**:
+    - Use a GPU if available
+    - Reduce image resolution
+    - Check system resources
+
+ 5. **PyTorch Installation Issues**:
+    - For newer GPUs (RTX 5090, etc.), use PyTorch nightly builds
+    - Install `uv` for better package management: `pip install uv`
+    - Check CUDA compatibility with your GPU
+    - Verify the transformers installation from the GitHub source
+
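+ For reference, the CPU-conversion guard used throughout `app.py` is simply:
+
+ ```python
+ import torch
+
+ def tensor_to_numpy(tensor):
+     """Move CUDA tensors to the CPU before converting to NumPy."""
+     if isinstance(tensor, torch.Tensor):
+         return tensor.cpu().numpy()
+     return tensor
+ ```
+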
+ ## 📁 Project Structure
+
+ ```
+ eomt_panoptic_seg/
+ ├── app.py            # Main application file
+ ├── requirements.txt  # Python dependencies
+ └── README.md         # This file
+ ```
+
+ ## 📄 License
+
+ Open source; feel free to use and modify as needed.
+
+ ## 🤝 Contributing
+
+ Contributions are welcome! Please feel free to submit a Pull Request.
+
  ---
 
+ **Note**: This application requires a GPU for optimal performance. CPU inference is supported but will be significantly slower. For newer NVIDIA GPUs, use PyTorch nightly builds for best compatibility. The app includes proper CUDA tensor handling to prevent device-conversion errors.
+
+ **Research**: This implementation is based on the CVPR 2025 paper ["Your ViT is Secretly an Image Segmentation Model"](https://www.tue-mps.org/eomt/) by Kerssies et al., demonstrating that plain Vision Transformers can achieve state-of-the-art segmentation performance.
app.py ADDED
@@ -0,0 +1,543 @@
+ import io
+
+ import cv2
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import requests
+ import torch
+ from PIL import Image
+ from transformers import AutoImageProcessor, EomtForUniversalSegmentation
+
+ # Load model globally to avoid reloading
+ print("Loading model...")
+ model_id = "tue-mps/coco_panoptic_eomt_large_640"
+ processor = AutoImageProcessor.from_pretrained(model_id)
+ model = EomtForUniversalSegmentation.from_pretrained(model_id)
+
+ # Check for CUDA availability and move model to GPU if available
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+ print(f"Model loaded successfully on {device}!")
+
+
+ def run_inference(image):
+     """Run panoptic segmentation inference"""
+     inputs = processor(images=image, return_tensors="pt")
+
+     # Move inputs to the same device as model
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     with torch.inference_mode():
+         outputs = model(**inputs)
+
+     target_sizes = [(image.height, image.width)]
+     preds = processor.post_process_panoptic_segmentation(
+         outputs, target_sizes=target_sizes
+     )
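+     # preds is a list with one dict per input image; each dict holds a
+     # "segmentation" map (HxW tensor of segment IDs) and "segments_info" metadata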
+
+     return preds[0]
+
+
+ def tensor_to_numpy(tensor):
+     """Convert tensor to numpy array, handling CUDA tensors"""
+     if isinstance(tensor, torch.Tensor):
+         return tensor.cpu().numpy()
+     return tensor
+
+
+ def visualize_mask(image, segmentation_mask):
+     """Show segmentation mask only"""
+     fig, ax = plt.subplots(1, 1, figsize=(12, 8))
+
+     # Segmentation mask - convert tensor to numpy
+     mask_np = tensor_to_numpy(segmentation_mask)
+
+     if mask_np.max() > 0:
+         ax.imshow(mask_np, cmap="tab20")
+     else:
+         # If no segments, show a blank image with text
+         ax.imshow(np.zeros_like(mask_np), cmap="gray")
+         ax.text(
+             0.5,
+             0.5,
+             "No segments detected",
+             transform=ax.transAxes,
+             ha="center",
+             va="center",
+             fontsize=16,
+             color="red",
+             weight="bold",
+         )
+
+     ax.axis("off")
+
+     plt.tight_layout()
+
+     # Convert to PIL Image
+     buf = io.BytesIO()
+     plt.savefig(buf, format="png", bbox_inches="tight", dpi=150)
+     buf.seek(0)
+     plt.close()
+
+     return Image.open(buf)
+
+
+ def visualize_overlay(image, segmentation_mask):
+     """Show segmentation overlay on original image"""
+     fig, ax = plt.subplots(1, 1, figsize=(12, 8))
+
+     # Original image
+     ax.imshow(image)
+
+     # Overlay segmentation mask with transparency
+     mask_np = tensor_to_numpy(segmentation_mask)
+     ax.imshow(mask_np, cmap="tab20", alpha=0.6)
+
+     ax.axis("off")
+
+     plt.tight_layout()
+
+     # Convert to PIL Image
+     buf = io.BytesIO()
+     plt.savefig(buf, format="png", bbox_inches="tight", dpi=150)
+     buf.seek(0)
+     plt.close()
+
+     return Image.open(buf)
+
+
+ def visualize_contours(image, segmentation_mask):
+     """Show contours of segments on original image"""
+     fig, ax = plt.subplots(1, 1, figsize=(12, 8))
+
+     # Original image
+     ax.imshow(image)
+
+     # Convert mask to numpy and find contours
+     mask_np = tensor_to_numpy(segmentation_mask).astype(np.uint8)
+
+     # Find unique segments
+     unique_segments = np.unique(mask_np)
+
+     # Create a colormap for distinct colors
+     colors = plt.cm.tab20(np.linspace(0, 1, len(unique_segments)))
+
+     # Draw contours for each segment
+     for i, segment_id in enumerate(unique_segments):
+         if segment_id == 0:  # Skip background
+             continue
+
+         # Create binary mask for this segment
+         binary_mask = (mask_np == segment_id).astype(np.uint8)
+
+         # Find contours
+         contours, _ = cv2.findContours(
+             binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+         )
+
+         # Draw contours with unique color for each segment
+         for contour in contours:
+             if len(contour) > 2:  # Only draw if contour has enough points
+                 contour = contour.reshape(-1, 2)
+                 ax.plot(
+                     contour[:, 0],
+                     contour[:, 1],
+                     color=colors[i % len(colors)],
+                     linewidth=2,
+                     alpha=0.8,
+                 )
+
+     ax.axis("off")
+
+     plt.tight_layout()
+
+     # Convert to PIL Image
+     buf = io.BytesIO()
+     plt.savefig(buf, format="png", bbox_inches="tight", dpi=150)
+     buf.seek(0)
+     plt.close()
+
+     return Image.open(buf)
+
+
+ def visualize_instance_masks(image, segmentation_mask):
+     """Show individual instance masks in a grid"""
+     mask_np = tensor_to_numpy(segmentation_mask)
+     unique_segments, counts = np.unique(mask_np, return_counts=True)
+
+     # Get top 9 segments by size (excluding background)
+     non_bg_indices = unique_segments != 0
+     if np.any(non_bg_indices):
+         top_segments = unique_segments[non_bg_indices][
+             np.argsort(counts[non_bg_indices])[-9:]
+         ]
+         top_counts = counts[non_bg_indices][np.argsort(counts[non_bg_indices])[-9:]]
+     else:
+         top_segments = []
+         top_counts = []
+
+     fig, axes = plt.subplots(3, 3, figsize=(15, 15))
+     axes = axes.flatten()
+
+     for i, (segment_id, count) in enumerate(zip(top_segments, top_counts)):
+         binary_mask = (mask_np == segment_id).astype(float)
+         axes[i].imshow(binary_mask, cmap="Blues")
+         axes[i].set_title(
+             f"Segment {segment_id}\nPixels: {count}", fontsize=10, weight="bold"
+         )
+         axes[i].axis("off")
+
+     # Fill empty subplots with informative text
+     for i in range(len(top_segments), 9):
+         axes[i].axis("off")
+         if i == 0 and len(top_segments) == 0:
+             axes[i].text(
+                 0.5,
+                 0.5,
+                 "No segments\ndetected",
+                 transform=axes[i].transAxes,
+                 ha="center",
+                 va="center",
+                 fontsize=12,
+                 color="red",
+                 weight="bold",
+             )
+
+     plt.tight_layout()
+
+     # Convert to PIL Image
+     buf = io.BytesIO()
+     plt.savefig(buf, format="png", bbox_inches="tight", dpi=150)
+     buf.seek(0)
+     plt.close()
+
+     return Image.open(buf)
+
+
+ def visualize_edges(image, segmentation_mask):
+     """Show edge detection on segmentation boundaries"""
+     mask_np = tensor_to_numpy(segmentation_mask)
+
+     fig, ax = plt.subplots(1, 1, figsize=(12, 8))
+
+     # Original image
+     ax.imshow(image)
+
+     # Edge detection on mask
+     if mask_np.max() > 0:  # Check if mask has any segments
+         edges = cv2.Canny((mask_np * 255 / mask_np.max()).astype(np.uint8), 50, 150)
+
+         # Create a colored edge overlay
+         edge_overlay = np.zeros((*edges.shape, 4))  # RGBA
+         edge_overlay[edges > 0] = [1, 1, 0, 1]  # Yellow with full alpha
+
+         # Overlay edges
+         ax.imshow(edge_overlay)
+     else:
+         # If no segments, just show the original image
+         ax.text(
+             0.5,
+             0.5,
+             "No segments detected",
+             transform=ax.transAxes,
+             ha="center",
+             va="center",
+             fontsize=16,
+             color="red",
+             weight="bold",
+         )
+
+     ax.axis("off")
+
+     plt.tight_layout()
+
+     # Convert to PIL Image
+     buf = io.BytesIO()
+     plt.savefig(buf, format="png", bbox_inches="tight", dpi=150)
+     buf.seek(0)
+     plt.close()
+
+     return Image.open(buf)
+
+
+ def visualize_segment_isolation(image, segmentation_mask):
+     """Show the largest segment isolated from the rest"""
+     mask_np = tensor_to_numpy(segmentation_mask)
+     unique_segments, counts = np.unique(mask_np, return_counts=True)
+
+     # Find largest segment (excluding background)
+     non_bg_indices = unique_segments != 0
+     if np.any(non_bg_indices):
+         largest_segment = unique_segments[non_bg_indices][
+             np.argmax(counts[non_bg_indices])
+         ]
+         largest_count = counts[non_bg_indices][np.argmax(counts[non_bg_indices])]
+     else:
+         largest_segment = unique_segments[np.argmax(counts)]
+         largest_count = counts[np.argmax(counts)]
+
+     fig, ax = plt.subplots(1, 1, figsize=(12, 8))
+
+     # Isolated segment
+     isolated_mask = (mask_np == largest_segment).astype(float)
+
+     if isolated_mask.max() > 0:
+         ax.imshow(isolated_mask, cmap="Reds")
+         ax.set_title(
+             f"Largest Segment (ID: {largest_segment}, Pixels: {largest_count})",
+             fontsize=14,
+             weight="bold",
+             pad=20,
+         )
+     else:
+         ax.imshow(np.zeros_like(isolated_mask), cmap="gray")
+         ax.text(
+             0.5,
+             0.5,
+             "No segments detected",
+             transform=ax.transAxes,
+             ha="center",
+             va="center",
+             fontsize=16,
+             color="red",
+             weight="bold",
+         )
+
+     ax.axis("off")
+
+     plt.tight_layout()
+
+     # Convert to PIL Image
+     buf = io.BytesIO()
+     plt.savefig(buf, format="png", bbox_inches="tight", dpi=150)
+     buf.seek(0)
+     plt.close()
+
+     return Image.open(buf)
+
+
+ def visualize_heatmap(image, segmentation_mask):
+     """Show boundary density heatmap"""
+     mask_np = tensor_to_numpy(segmentation_mask)
+
+     fig, ax = plt.subplots(1, 1, figsize=(12, 8))
+
+     if mask_np.max() > 0:
+         # Calculate gradient magnitude for boundary detection
+         gradient_magnitude = np.gradient(mask_np.astype(float))
+         gradient_magnitude = np.sqrt(
+             gradient_magnitude[0] ** 2 + gradient_magnitude[1] ** 2
+         )
+
+         # Boundary heatmap
+         im = ax.imshow(gradient_magnitude, cmap="hot")
+         ax.set_title("Boundary Density Heatmap", fontsize=14, weight="bold", pad=20)
+         ax.axis("off")
+         plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="Gradient Magnitude")
+     else:
+         # If no segments, show a blank heatmap
+         ax.imshow(np.zeros_like(mask_np), cmap="hot")
+         ax.text(
+             0.5,
+             0.5,
+             "No segments detected",
+             transform=ax.transAxes,
+             ha="center",
+             va="center",
+             fontsize=16,
+             color="red",
+             weight="bold",
+         )
+         ax.axis("off")
+
+     plt.tight_layout()
+
+     # Convert to PIL Image
+     buf = io.BytesIO()
+     plt.savefig(buf, format="png", bbox_inches="tight", dpi=150)
+     buf.seek(0)
+     plt.close()
+
+     return Image.open(buf)
+
+
+ def create_visualization(image, viz_type):
+     """Create visualization based on selected type"""
+     if image is None:
+         return None
+
+     try:
+         # Run inference
+         prediction = run_inference(image)
+         segmentation_mask = prediction["segmentation"]
+
+         if viz_type == "Mask":
+             return visualize_mask(image, segmentation_mask)
+         elif viz_type == "Overlay":
+             return visualize_overlay(image, segmentation_mask)
+         elif viz_type == "Contours":
+             return visualize_contours(image, segmentation_mask)
+         elif viz_type == "Instance Masks":
+             return visualize_instance_masks(image, segmentation_mask)
+         elif viz_type == "Edge Detection":
+             return visualize_edges(image, segmentation_mask)
+         elif viz_type == "Segment Isolation":
+             return visualize_segment_isolation(image, segmentation_mask)
+         elif viz_type == "Heatmap":
+             return visualize_heatmap(image, segmentation_mask)
+         else:
+             # Default fallback
+             return visualize_mask(image, segmentation_mask)
+
+     except Exception as e:
+         print(f"Error in visualization: {e}")
+         # Return a simple error visualization
+         fig, ax = plt.subplots(1, 1, figsize=(12, 8))
+         ax.text(
+             0.5,
+             0.5,
+             f"Error during processing:\n{str(e)}",
+             transform=ax.transAxes,
+             ha="center",
+             va="center",
+             fontsize=12,
+             color="red",
+             weight="bold",
+         )
+         ax.axis("off")
+
+         buf = io.BytesIO()
+         plt.savefig(buf, format="png", bbox_inches="tight", dpi=150)
+         buf.seek(0)
+         plt.close()
+
+         return Image.open(buf)
+
+
+ def load_sample_image(img_path):
+     """Load a sample image from URL"""
+     try:
+         response = requests.get(img_path, stream=True)
+         response.raise_for_status()
+         return Image.open(response.raw)
+     except Exception as e:
+         print(f"Error loading image: {e}")
+         return None
+
+
+ # Create Gradio interface
+ def create_interface():
+     with gr.Blocks(
+         title="Panoptic Segmentation Visualizer", theme=gr.themes.Soft()
+     ) as demo:
+         gr.Markdown("""
+         # 🎨 Panoptic Segmentation Visualizer
+
+         Upload an image and select a visualization type to see different ways of viewing the panoptic segmentation results.
+         The model used is `tue-mps/coco_panoptic_eomt_large_640`.
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 image_input = gr.Image(label="Upload Image", type="pil", height=400)
+
+                 viz_type = gr.Radio(
+                     choices=[
+                         "Mask",
+                         "Overlay",
+                         "Contours",
+                         "Instance Masks",
+                         "Edge Detection",
+                         "Segment Isolation",
+                         "Heatmap",
+                     ],
+                     label="Visualization Type",
+                     value="Mask",
+                     info="Choose how to visualize the segmentation results",
+                 )
+
+                 process_btn = gr.Button(
+                     "🚀 Process Image", variant="primary", size="lg"
+                 )
+
+                 gr.Markdown("""
+                 ### Visualization Types:
+                 - **Mask**: Segmentation mask with color-coded segments
+                 - **Overlay**: Transparent segmentation overlay on original image
+                 - **Contours**: Segment boundaries outlined on original image
+                 - **Instance Masks**: Individual instance masks in a grid (top 9 by size)
+                 - **Edge Detection**: Segmentation boundaries highlighted in yellow
+                 - **Segment Isolation**: Shows the largest segment isolated from the rest
+                 - **Heatmap**: Boundary density visualization with color mapping
+                 """)
+
+             with gr.Column(scale=2):
+                 output_image = gr.Image(
+                     label="Segmentation Result", type="pil", height=600
+                 )
+
+         process_btn.click(
+             fn=create_visualization,
+             inputs=[image_input, viz_type],
+             outputs=output_image,
+         )
+
+         # Sample images with thumbnails
+         gr.Markdown("### 📸 Try with sample images:")
+
+         sample_images = [
+             ("http://images.cocodataset.org/val2017/000000039769.jpg", "Cats on Couch"),
+             ("http://images.cocodataset.org/val2017/000000397133.jpg", "Street Scene"),
+             ("http://images.cocodataset.org/val2017/000000037777.jpg", "Living Room"),
+             (
+                 "http://images.cocodataset.org/val2017/000000174482.jpg",
+                 "Person with Laptop",
+             ),
+             ("http://images.cocodataset.org/val2017/000000000785.jpg", "Dining Table"),
+         ]
+
+         def create_thumbnail_gallery():
+             """Create a gallery of clickable thumbnails"""
+             gallery_images = []
+             for img_url, img_name in sample_images:
+                 try:
+                     img = load_sample_image(img_url)
+                     if img:
+                         # Resize to thumbnail while maintaining aspect ratio
+                         img.thumbnail((200, 200), Image.Resampling.LANCZOS)
+                         gallery_images.append((img, img_name))
+                 except Exception as e:
+                     print(f"Failed to load {img_name}: {e}")
+                     continue
+             return gallery_images
+
+         with gr.Row():
+             thumbnail_gallery = gr.Gallery(
+                 value=create_thumbnail_gallery(),
+                 label="Sample Images",
+                 show_label=True,
+                 elem_id="thumbnail_gallery",
+                 columns=5,
+                 rows=1,
+                 object_fit="contain",
+                 height=200,
+                 allow_preview=False,
+             )
+
+         def select_from_gallery(evt: gr.SelectData):
+             """Handle gallery selection"""
+             selected_idx = evt.index
+             if selected_idx < len(sample_images):
+                 img_url, _ = sample_images[selected_idx]
+                 return load_sample_image(img_url)
+             return None
+
+         thumbnail_gallery.select(select_from_gallery, outputs=image_input)
+
+     return demo
+
+
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch(share=True, server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ # Core ML and Computer Vision
+ # torch        # For newer NVIDIA GPUs (e.g. RTX 5090) use: uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
+ # torchvision
+
+ git+https://github.com/huggingface/transformers.git
+ pillow
+ numpy
+ opencv-python
+ scipy
+
+ # Gradio web interface
+ gradio
+
+ # Visualization
+ matplotlib
+
+ # HTTP requests
+ requests
+
+ # Pre-trained model
+ # The model tue-mps/coco_panoptic_eomt_large_640 is downloaded automatically on first run