Yanlin Zhang committed
Commit 899327c · 1 Parent(s): 543ad60

add app.py

Files changed (3)
  1. README.md +21 -0
  2. app.py +401 -0
  3. requirements.txt +7 -0
README.md CHANGED
@@ -11,3 +11,24 @@ license: apache-2.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+## SAM3 Vehicle Trajectory Space
+
+This Space turns `facebook/sam3` into a ready-to-use pipeline for extracting
+small- and large-vehicle trajectories from aerial surveillance videos.
+
+### Quick start
+
+1. Authenticate with Hugging Face to access the gated SAM3 checkpoint:
+   ```bash
+   hf auth login
+   ```
+2. Upload an aerial MP4/MOV clip. The app automatically sends the prompts
+   `small-vehicle` and `large-vehicle` to SAM3, overlays the resulting masks,
+   and links detections over time to form trajectories.
+3. Download the rendered video and inspect the per-track summary table.
+
+The UI exposes stride, resize, and frame-limit controls so you can trade off
+latency versus coverage depending on the clip length. All heavy lifting (frame
+decoding, segmentation, mask rendering, trajectory stitching) happens on the
+Space so you only need to provide the footage.
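Because `app.py` registers the endpoint with `api_name="analyze"`, the Space can also be driven programmatically. Below is a minimal sketch using `gradio_client`, assuming a recent `gradio_client` release, a hypothetical Space id `your-username/sam3-vehicle-trajectories`, and a hypothetical local clip `aerial_clip.mp4`; the exact return format of the track table depends on the Gradio version.

```python
from gradio_client import Client, handle_file

# Hypothetical Space id — substitute the real <owner>/<space-name>.
client = Client("your-username/sam3-vehicle-trajectories")

# Arguments mirror the UI inputs: video, frame stride, max frames, resize long edge.
overlay_video, track_summary = client.predict(
    handle_file("aerial_clip.mp4"),  # hypothetical local clip to upload
    5,      # frame stride
    240,    # max frames to process
    1280,   # resize longest edge (px)
    api_name="/analyze",
)

print(overlay_video)   # local path of the rendered overlay video
print(track_summary)   # per-track summary rows
```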
app.py ADDED
@@ -0,0 +1,401 @@
+"""
+Vehicle trajectory extractor powered by SAM3.
+
+The app takes an aerial video, segments small and large vehicles frame-by-frame
+with text prompts (`small-vehicle`, `large-vehicle`), and draws their
+trajectories on top of the footage.
+"""
+
+from __future__ import annotations
+
+import math
+import os
+import tempfile
+import uuid
+from dataclasses import dataclass
+from typing import Dict, List, Sequence, Tuple
+
+import cv2
+import gradio as gr
+import numpy as np
+from PIL import Image
+import torch
+from transformers import AutoImageProcessor, AutoModel
+
+# -----------------------------------------------------------------------------
+# Configuration
+# -----------------------------------------------------------------------------
+
+MODEL_ID = "facebook/sam3"
+TEXT_PROMPTS = ["small-vehicle", "large-vehicle"]
+MIN_MASK_PIXELS = 150  # filter spurious detections
+MAX_TRACK_GAP = 3  # frames
+DEFAULT_FRAME_STRIDE = 5
+MAX_PROCESSED_FRAMES = 720
+
+CLASS_COLORS: Dict[str, Tuple[int, int, int]] = {
+    "small-vehicle": (20, 148, 245),  # RGB
+    "large-vehicle": (255, 120, 30),
+}
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
+
+# -----------------------------------------------------------------------------
+# Model + processor
+# -----------------------------------------------------------------------------
+
+processor = AutoImageProcessor.from_pretrained(MODEL_ID)
+model = AutoModel.from_pretrained(MODEL_ID, torch_dtype=DTYPE).to(DEVICE)
+model.eval()
+
+
+# -----------------------------------------------------------------------------
+# Tracking utilities
+# -----------------------------------------------------------------------------
+
+@dataclass
+class Track:
+    track_id: int
+    label: str
+    points: List[Tuple[int, float, float]]
+    last_frame: int
+    score: float | None
+
+
+def _post_process(outputs, height: int, width: int):
+    target_sizes = [(height, width)]
+
+    if hasattr(processor, "post_process_instance_segmentation"):
+        return processor.post_process_instance_segmentation(
+            outputs=outputs,
+            target_sizes=target_sizes,
+            threshold=0.35,
+            mask_threshold=0.4,
+            overlap_mask_area_threshold=0.5,
+        )[0]
+
+    if hasattr(processor, "post_process_semantic_segmentation"):
+        segmentation = processor.post_process_semantic_segmentation(
+            outputs=outputs,
+            target_sizes=target_sizes,
+        )[0]
+        return {
+            "masks": segmentation.unsqueeze(0),
+            "scores": torch.ones(1),
+            "labels": torch.zeros(1, dtype=torch.int64),
+        }
+
+    raise gr.Error(
+        "This version of transformers does not expose SAM3 post-processing helpers. "
+        "Please ensure transformers>=4.46.0 is installed."
+    )
+
+
+def _extract_detections(frame_rgb: np.ndarray) -> List[Dict]:
+    pil_image = Image.fromarray(frame_rgb)
+    detections: List[Dict] = []
+
+    for label in TEXT_PROMPTS:
+        inputs = processor(images=pil_image, text=label, return_tensors="pt")
+        # Move tensors to the model device and cast floating-point inputs to the
+        # model dtype (fp16 on GPU) so they match the half-precision weights.
+        inputs = {
+            k: (
+                v.to(DEVICE, DTYPE if v.is_floating_point() else v.dtype)
+                if isinstance(v, torch.Tensor)
+                else v
+            )
+            for k, v in inputs.items()
+        }
+
+        with torch.inference_mode():
+            outputs = model(**inputs)
+
+        processed = _post_process(outputs, pil_image.height, pil_image.width)
+        masks = processed.get("masks", [])
+        scores = processed.get("scores", [None] * len(masks))
+
+        for mask_tensor, score in zip(masks, scores):
+            mask_np = mask_tensor.squeeze().detach().cpu().numpy()
+            if mask_np.ndim == 3:
+                mask_np = mask_np[0]
+
+            binary_mask = mask_np > 0.5
+            area = int(binary_mask.sum())
+            if area < MIN_MASK_PIXELS:
+                continue
+
+            ys, xs = np.nonzero(binary_mask)
+            if len(xs) == 0:
+                continue
+
+            centroid = (float(xs.mean()), float(ys.mean()))
+            detections.append(
+                {
+                    "label": label,
+                    "mask": binary_mask,
+                    "score": float(score) if score is not None else None,
+                    "centroid": centroid,
+                    "area": area,
+                }
+            )
+
+    return detections
+
+
+def _update_tracks(
+    tracks: List[Track],
+    detections: Sequence[Dict],
+    frame_idx: int,
+    max_distance: float,
+) -> None:
+    # Greedy nearest-centroid association: each detection joins the closest live
+    # track of the same label within max_distance, otherwise it starts a new
+    # track. Tracks idle for more than MAX_TRACK_GAP frames are skipped.
+    for detection in detections:
+        centroid = np.array(detection["centroid"])
+        best_track = None
+        best_distance = math.inf
+
+        for track in tracks:
+            if track.label != detection["label"]:
+                continue
+            if frame_idx - track.last_frame > MAX_TRACK_GAP:
+                continue
+
+            prev_point = np.array(track.points[-1][1:])
+            dist = np.linalg.norm(centroid - prev_point)
+            if dist < best_distance and dist <= max_distance:
+                best_distance = dist
+                best_track = track
+
+        if best_track:
+            best_track.points.append((frame_idx, *detection["centroid"]))
+            best_track.last_frame = frame_idx
+            best_track.score = detection["score"]
+        else:
+            new_track = Track(
+                track_id=len(tracks) + 1,
+                label=detection["label"],
+                points=[(frame_idx, *detection["centroid"])],
+                last_frame=frame_idx,
+                score=detection["score"],
+            )
+            tracks.append(new_track)
+
+
+def _blend_mask(frame: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int], alpha: float = 0.45):
+    overlay = frame.copy()
+    overlay[mask] = (1 - alpha) * overlay[mask] + alpha * np.array(color, dtype=np.float32)
+    return overlay
+
+
+def _draw_annotations(
+    frame_rgb: np.ndarray,
+    detections: Sequence[Dict],
+    tracks: Sequence[Track],
+    frame_idx: int,
+):
+    # The buffer stays in RGB until analyze_video converts it for writing, so the
+    # OpenCV drawing calls below receive the RGB colour tuples directly.
+    annotated = frame_rgb.astype(np.float32)
+
+    for det in detections:
+        color_rgb = CLASS_COLORS.get(det["label"], (255, 255, 255))
+        draw_color = tuple(int(c) for c in color_rgb)
+
+        annotated = _blend_mask(annotated, det["mask"], color_rgb)
+
+        cx, cy = det["centroid"]
+        cv2.circle(annotated, (int(cx), int(cy)), 4, draw_color, -1)
+        cv2.putText(
+            annotated,
+            det["label"],
+            (int(cx) + 4, int(cy) - 4),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.4,
+            draw_color,
+            1,
+            cv2.LINE_AA,
+        )
+
+    for track in tracks:
+        if len(track.points) < 2:
+            continue
+        if track.points[-1][0] < frame_idx - MAX_TRACK_GAP:
+            continue
+
+        color_rgb = CLASS_COLORS.get(track.label, (255, 255, 255))
+        draw_color = tuple(int(c) for c in color_rgb)
+        pts = [
+            (int(x), int(y))
+            for (f_idx, x, y) in track.points
+            if f_idx <= frame_idx
+        ]
+
+        for i in range(1, len(pts)):
+            cv2.line(annotated, pts[i - 1], pts[i], draw_color, 2, cv2.LINE_AA)
+
+        cv2.circle(annotated, pts[-1], 5, draw_color, -1)
+
+    return np.clip(annotated, 0, 255).astype(np.uint8)
+
+
+def _summarize_tracks(tracks: Sequence[Track]) -> List[Dict]:
+    summary = []
+    for track in tracks:
+        if len(track.points) < 2:
+            continue
+
+        distances = []
+        for (prev_frame, x1, y1), (curr_frame, x2, y2) in zip(track.points, track.points[1:]):
+            distances.append(math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2))
+
+        summary.append(
+            {
+                "track_id": track.track_id,
+                "label": track.label,
+                "frames": len(track.points),
+                "start_frame": track.points[0][0],
+                "end_frame": track.points[-1][0],
+                "path_px": round(float(sum(distances)), 2),
+            }
+        )
+    return summary
+
+
+# -----------------------------------------------------------------------------
+# Video processing
+# -----------------------------------------------------------------------------
+
+def analyze_video(
+    video_path: str,
+    frame_stride: int = DEFAULT_FRAME_STRIDE,
+    max_frames: int = MAX_PROCESSED_FRAMES,
+    resize_long_edge: int = 1280,
+) -> Tuple[str, List[Dict]]:
+    if not video_path:
+        raise gr.Error("Please upload an aerial video (MP4, MOV, ...).")
+
+    capture = cv2.VideoCapture(video_path)
+    if not capture.isOpened():
+        raise gr.Error("Unable to read the uploaded video.")
+
+    fps = capture.get(cv2.CAP_PROP_FPS) or 15
+    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    diag = math.sqrt(width**2 + height**2)
+    max_assign_distance = 0.04 * diag
+
+    processed_frames = []
+    tracks: List[Track] = []
+
+    frame_index = 0
+    processed_count = 0
+
+    while processed_count < max_frames:
+        ret, frame_bgr = capture.read()
+        if not ret:
+            break
+
+        if frame_index % frame_stride != 0:
+            frame_index += 1
+            continue
+
+        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+        frame_rgb = _resize_long_edge(frame_rgb, resize_long_edge)
+
+        detections = _extract_detections(frame_rgb)
+        _update_tracks(tracks, detections, frame_index, max_assign_distance)
+        annotated = _draw_annotations(frame_rgb, detections, tracks, frame_index)
+
+        processed_frames.append(cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
+        processed_count += 1
+        frame_index += 1
+
+    capture.release()
+
+    if not processed_frames:
+        raise gr.Error("No frames were processed. Try lowering the stride or uploading a different video.")
+
+    output_path = _write_video(processed_frames, fps / max(frame_stride, 1))
+    summary = _summarize_tracks(tracks)
+
+    return output_path, summary
+
+
+def _resize_long_edge(frame_rgb: np.ndarray, target_long_edge: int) -> np.ndarray:
+    h, w, _ = frame_rgb.shape
+    long_edge = max(h, w)
+    if long_edge <= target_long_edge:
+        return frame_rgb
+
+    scale = target_long_edge / long_edge
+    new_size = (int(w * scale), int(h * scale))
+    resized = cv2.resize(frame_rgb, new_size, interpolation=cv2.INTER_AREA)
+    return resized
+
+
+def _write_video(frames: Sequence[np.ndarray], fps: float) -> str:
+    height, width, _ = frames[0].shape
+    tmp_path = os.path.join(tempfile.gettempdir(), f"sam3-trajectories-{uuid.uuid4().hex}.mp4")
+
+    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+    writer = cv2.VideoWriter(tmp_path, fourcc, max(fps, 1.0), (width, height))
+    for frame in frames:
+        writer.write(frame)
+    writer.release()
+    return tmp_path
+
+
+# -----------------------------------------------------------------------------
+# Gradio UI
+# -----------------------------------------------------------------------------
+
+with gr.Blocks(title="SAM3 Vehicle Trajectories") as demo:
+    gr.Markdown(
+        """
+        ### SAM3 for Vehicle Trajectories
+        1. Upload an aerial surveillance video.
+        2. The app prompts SAM3 with `small-vehicle` and `large-vehicle`.
+        3. Segmentations are linked across frames to render motion trails.
+        """
+    )
+
+    with gr.Row():
+        video_input = gr.Video(label="Aerial video (MP4/MOV)")
+        controls = gr.Column()
+        with controls:
+            stride_slider = gr.Slider(
+                label="Frame stride",
+                minimum=1,
+                maximum=12,
+                value=DEFAULT_FRAME_STRIDE,
+                step=1,
+                info="Process one frame every N frames",
+            )
+            max_frames_slider = gr.Slider(
+                label="Max frames to process",
+                minimum=30,
+                maximum=1000,
+                value=MAX_PROCESSED_FRAMES,
+                step=10,
+            )
+            resize_slider = gr.Slider(
+                label="Resize longest edge (px)",
+                minimum=640,
+                maximum=1920,
+                value=1280,
+                step=40,
+            )
+
+    output_video = gr.Video(label="Overlay with trajectories")
+    track_table = gr.Dataframe(
+        headers=["track_id", "label", "frames", "start_frame", "end_frame", "path_px"],
+        datatype=["number", "str", "number", "number", "number", "number"],
+        wrap=True,
+        label="Track summary",
+    )
+
+    run_button = gr.Button("Extract trajectories", variant="primary")
+
+    run_button.click(
+        fn=analyze_video,
+        inputs=[video_input, stride_slider, max_frames_slider, resize_slider],
+        outputs=[output_video, track_table],
+        api_name="analyze",
+    )
+
+
+if __name__ == "__main__":
+    demo.launch()
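For local debugging outside the Gradio UI, `analyze_video` can also be called directly. A minimal sketch, assuming the gated `facebook/sam3` checkpoint is accessible (importing `app` downloads it and loads the model at import time) and a hypothetical local clip `aerial_clip.mp4`:

```python
# Hypothetical local smoke test: run the pipeline on a short clip and print the tracks.
from app import analyze_video

overlay_path, track_rows = analyze_video(
    "aerial_clip.mp4",   # hypothetical input clip
    frame_stride=5,
    max_frames=60,       # keep the run short
    resize_long_edge=960,
)

print(overlay_path)      # rendered overlay video in the temp directory
for row in track_rows:
    print(row)           # per-track summary dicts
```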
requirements.txt ADDED
@@ -0,0 +1,7 @@
+torch
+transformers>=4.46.0
+accelerate
+gradio
+pillow
+opencv-python
+numpy