manu02 committed
Commit c98fd6f · 1 Parent(s): 65cc576

Upload app.py

Files changed (1)
  app.py  +631 −0
app.py ADDED
@@ -0,0 +1,631 @@
# app.py
"""
🖼️→📝 Image-to-Text Attention Visualizer (Custom Model)
- Loads your custom model via create_complete_model()
- Accepts an image, applies your transform, then calls:
      model.generate(pixel_values=..., max_new_tokens=..., output_attentions=True)
- Selector lists ONLY generated words (no prompt tokens).
- Viewer (single row) shows:
      (1) original image,
      (2) original + colored attention heatmap overlay,
      (3) heatmap alone (colored).
- Heatmap is built from the first 1024 image tokens (32×32), then upscaled to the image size.
- Text block below shows word-level attention over generated tokens (no return_offsets_mapping used).
- Fixes deprecations: Matplotlib colormap API & Pillow mode inference.
"""

import os
import re
import random
from typing import List, Tuple, Optional

import gradio as gr
import torch
import numpy as np
from PIL import Image
from safetensors.torch import load_model

# Optional: nicer colormap (Matplotlib >=3.7 API; no deprecation warnings)
try:
    import matplotlib as mpl
    _HAS_MPL = True
    _COLORMAP = mpl.colormaps.get_cmap("magma")
except Exception:
    _HAS_MPL = False
    _COLORMAP = None

# ========= Your utilities & model =========
from utils.processing import image_transform, pil_from_path
from utils.complete_model import create_complete_model

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_complete_model(device=DEVICE, attention_implementation="eager")
SAFETENSOR_PATH = "complete_model.safetensor"
try:
    load_model(model, SAFETENSOR_PATH)
except Exception as e:
    print(f"Error loading model: {e}, continuing with uninitialized weights.")
model.eval()
device = DEVICE

# --- Grab tokenizer from your model ---
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is None:
    raise ValueError("Expected `model.tokenizer` to exist and be a HF-like tokenizer.")

# --- Fix PAD/EOS ambiguity (and resize embeddings if applicable) ---
needs_resize = False
pad_id = getattr(tokenizer, "pad_token_id", None)
eos_id = getattr(tokenizer, "eos_token_id", None)
if pad_id is None or (eos_id is not None and pad_id == eos_id):
    tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    needs_resize = True

# Try common resize hooks safely (only if your decoder actually uses tokenizer vocab)
if needs_resize:
    resize_fns = [
        getattr(getattr(model, "decoder", None), "resize_token_embeddings", None),
        getattr(model, "resize_token_embeddings", None),
    ]
    for fn in resize_fns:
        if callable(fn):
            try:
                fn(len(tokenizer))
                break
            except Exception:
                # If your model doesn't need resizing (separate vocab), it's fine.
                pass

# ========= Regex for words (words + punctuation) =========
WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")

# ========= Model metadata (for slider ranges) =========
def model_heads_layers():
    def _get(obj, *names, default=None):
        for n in names:
            if obj is None:
                return default
            if hasattr(obj, n):
                try:
                    return int(getattr(obj, n))
                except Exception:
                    return default
        return default

    cfg_candidates = [
        getattr(model, "config", None),
        getattr(getattr(model, "decoder", None), "config", None),
        getattr(getattr(model, "lm_head", None), "config", None),
    ]
    L = H = None
    for cfg in cfg_candidates:
        if L is None:
            L = _get(cfg, "num_hidden_layers", "n_layer", default=None)
        if H is None:
            H = _get(cfg, "num_attention_heads", "n_head", default=None)
    if L is None: L = 12
    if H is None: H = 12
    return max(1, L), max(1, H)

# ========= Attention utils =========
def get_attention_for_token_layer(
    attentions,
    token_index,
    layer_index,
    batch_index=0,
    head_index=0,
    mean_across_layers=True,
    mean_across_heads=True,
):
    """
    `attentions`:
        tuple length = #generated tokens
        attentions[t] -> tuple over layers; each layer tensor is (batch, heads, q, k)
    """
    token_attention = attentions[token_index]

    if mean_across_layers:
        layer_attention = torch.stack(token_attention).mean(dim=0)  # (batch, heads, q, k)
    else:
        layer_attention = token_attention[int(layer_index)]         # (batch, heads, q, k)

    batch_attention = layer_attention[int(batch_index)]             # (heads, q, k)

    if mean_across_heads:
        head_attention = batch_attention.mean(dim=0)                # (q, k)
    else:
        head_attention = batch_attention[int(head_index)]           # (q, k)

    return head_attention.squeeze(0)  # q==1 -> (k,)
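
# Shape walkthrough (illustrative; assumes a KV-cached decoder, so q == 1 after the
# first step, and the 1024-token image prefix used throughout this app):
#   attentions[t][layer]                 -> (1, num_heads, 1, 1024 + t + 1)
#   get_attention_for_token_layer(...)   -> (1024 + t + 1,)
# i.e. one weight per key: the 1024 image tokens first, then the generated tokens so far.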

# ========= Tokens → words mapping (no offset_mapping needed) =========
def _words_and_map_from_tokens_simple(token_ids: List[int]) -> Tuple[List[str], List[int]]:
    """
    Works with slow/fast tokenizers. No return_offsets_mapping.
    Steps:
      1) detok token_ids
      2) regex-split words and get their char-end positions
      3) for each word-end (we), encode detok[:we] w/ add_special_tokens=False
         last token index = len(prefix_ids) - 1
    """
    if not token_ids:
        return [], []

    toks = tokenizer.convert_ids_to_tokens(token_ids)
    detok = tokenizer.convert_tokens_to_string(toks)

    matches = list(re.finditer(WORD_RE, detok))
    words = [m.group(0) for m in matches]
    ends = [m.span()[1] for m in matches]  # char end (exclusive)

    word2tok: List[int] = []
    for we in ends:
        prefix_ids = tokenizer.encode(detok[:we], add_special_tokens=False)
        if not prefix_ids:
            word2tok.append(0)
            continue
        last_idx = len(prefix_ids) - 1
        last_idx = max(0, min(last_idx, len(token_ids) - 1))
        word2tok.append(last_idx)

    return words, word2tok
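
# Illustrative mapping (hypothetical sub-word split; actual pieces depend on your tokenizer):
#   token_ids -> tokens  ["A", " small", " dog", " run", "ning", "."]
#   detok                "A small dog running."
#   words                ["A", "small", "dog", "running", "."]
#   word2tok             [0, 1, 2, 4, 5]   # index of the last token completing each word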

def _strip_trailing_special(ids: List[int]) -> List[int]:
    specials = set(getattr(tokenizer, "all_special_ids", []) or [])
    j = len(ids)
    while j > 0 and ids[j - 1] in specials:
        j -= 1
    return ids[:j]

# ========= Visualization (word-level for generated text) =========
def generate_word_visualization_gen_only(
    words_gen: List[str],
    word_ends_rel: List[int],
    gen_attn_values: np.ndarray,
    selected_token_rel_idx: int,
) -> str:
    """
    words_gen: generated words only
    word_ends_rel: last-token indices of each generated word (relative to generation)
    gen_attn_values: length == len(gen_token_ids), attention over generated tokens only
                     (zeros for future tokens padded at the end)
    """
    if not words_gen or gen_attn_values is None or len(gen_attn_values) == 0:
        return (
            "<div style='width:100%;'>"
            "  <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
            "    <div style='color:#ddd;'>No text attention values.</div>"
            "  </div>"
            "</div>"
        )

    # compute word starts from ends (inclusive indexing)
    starts = []
    for i, end in enumerate(word_ends_rel):
        if i == 0:
            starts.append(0)
        else:
            starts.append(min(word_ends_rel[i - 1] + 1, end))

    # sum attention per word
    word_scores = []
    T = len(gen_attn_values)
    for i, end in enumerate(word_ends_rel):
        start = starts[i]
        if start > end:
            start = end
        s = max(0, min(start, T - 1))
        e = max(0, min(end, T - 1))
        if e < s:
            s, e = e, s
        word_scores.append(float(gen_attn_values[s:e + 1].sum()))

    max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0)

    # find selected word (contains selected token idx)
    selected_word_idx = None
    for i, end in enumerate(word_ends_rel):
        if selected_token_rel_idx <= end:
            selected_word_idx = i
            break
    if selected_word_idx is None and word_ends_rel:
        selected_word_idx = len(word_ends_rel) - 1

    spans = []
    for i, w in enumerate(words_gen):
        alpha = min(1.0, word_scores[i] / max_attn) if max_attn > 0 else 0.0
        bg = f"rgba(66,133,244,{alpha:.3f})"
        border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent"
        spans.append(
            f"<span style='display:inline-block;background:{bg};border:{border};"
            f"border-radius:6px;padding:2px 6px;margin:2px 4px 4px 0;color:#fff;'>"
            f"{w}</span>"
        )

    return (
        "<div style='width:100%;'>"
        "  <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
        "    <div style='white-space:normal;line-height:1.8;'>"
        f"      {''.join(spans)}"
        "    </div>"
        "  </div>"
        "</div>"
    )
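
# Worked example of the per-word aggregation above (made-up numbers):
#   word_ends_rel   = [0, 2, 3]  -> starts = [0, 1, 3]
#   gen_attn_values = [0.10, 0.05, 0.25, 0.60]
#   word_scores     = [0.10, 0.30, 0.60]
# Each word is shaded with alpha = score / max(word_scores); the word containing the
# selected token gets a white border.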

# ========= Heatmap helpers for 1024 image tokens =========
def _attention_to_heatmap_uint8(attn_1d: np.ndarray, img_token_len: int = 1024, side: int = 32) -> np.ndarray:
    """
    attn_1d: (k,) attention over keys for a given generation step; first 1024 are image tokens.
    Returns a (32, 32) uint8 grayscale array.
    """
    # take first 1024 (image tokens); pad/truncate as needed
    if attn_1d.shape[0] < img_token_len:
        img_part = np.zeros(img_token_len, dtype=float)
        img_part[: attn_1d.shape[0]] = attn_1d
    else:
        img_part = attn_1d[:img_token_len]

    # normalize to [0,1]
    mn, mx = float(img_part.min()), float(img_part.max())
    denom = (mx - mn) if (mx - mn) > 1e-12 else 1.0
    norm = (img_part - mn) / denom

    # return uint8 (0–255)
    return (norm.reshape(side, side) * 255.0).astype(np.uint8)
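
# Spatial reading of the 32×32 grid (assumes the encoder emits its 1024 tokens in
# row-major order, matching the reshape above): with the 512×512 input used in
# _prepare_image_tensor, cell (r, c) of the heatmap covers the 16×16-pixel patch at
# rows r*16..r*16+15, columns c*16..c*16+15 of the resized image.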

def _colorize_heatmap(heatmap_u8: np.ndarray) -> Image.Image:
    """
    Convert (H,W) uint8 grayscale to RGB heatmap using matplotlib (if available) or a simple fallback.
    """
    if _HAS_MPL and _COLORMAP is not None:
        colored = (_COLORMAP(heatmap_u8.astype(np.float32) / 255.0)[:, :, :3] * 255.0).astype(np.uint8)
        return Image.fromarray(colored)  # Pillow infers RGB
    else:
        # Fallback: map grayscale to red-yellow (simple linear)
        g = heatmap_u8.astype(np.float32) / 255.0
        r = (g * 255.0).clip(0, 255).astype(np.uint8)
        g2 = (np.sqrt(g) * 255.0).clip(0, 255).astype(np.uint8)
        b = np.zeros_like(r, dtype=np.uint8)
        rgb = np.stack([r, g2, b], axis=-1)
        return Image.fromarray(rgb)  # Pillow infers RGB

def _resize_like(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
    return img.resize(target_size, resample=Image.BILINEAR)

def _make_overlay(orig: Image.Image, heatmap_rgb: Image.Image, alpha: float = 0.35) -> Image.Image:
    """
    Blend heatmap over original. alpha in [0,1].
    """
    if heatmap_rgb.size != orig.size:
        heatmap_rgb = _resize_like(heatmap_rgb, orig.size)
    base = orig.convert("RGBA")
    overlay = heatmap_rgb.convert("RGBA")
    # set global alpha
    r, g, b = overlay.split()[:3]
    a = Image.new("L", overlay.size, int(alpha * 255))
    overlay = Image.merge("RGBA", (r, g, b, a))
    return Image.alpha_composite(base, overlay).convert("RGB")
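
# With an opaque base and a constant overlay alpha a, alpha_composite reduces to a
# per-pixel linear blend: out = a * heatmap + (1 - a) * original, so alpha=0.35 keeps
# the photo clearly visible under the colormap.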

# ========= Core (image → generate) =========
def _prepare_image_tensor(pil_img, img_size=512):
    tfm = image_transform(img_size=img_size)
    tens = tfm(pil_img).unsqueeze(0).to(device, non_blocking=True)  # [1,3,H,W]
    return tens

def run_generation(pil_image, max_new_tokens, layer, head, mean_layers, mean_heads):
    """
    1) Transform image
    2) model.generate(pixel_values=..., max_new_tokens=..., output_attentions=True)
       expected to return (gen_ids, gen_text, attentions)
    3) Build selector over generated words only
    4) Initial visualization -> (orig, overlay, heatmap, word HTML)
    """
    if pil_image is None:
        # Return placeholders
        blank = Image.new("RGB", (256, 256), "black")
        return (
            None, None, 1024, None, None,
            gr.update(choices=[], value=None),
            blank,                                     # original
            blank,                                     # overlay
            np.zeros((256, 256, 3), dtype=np.uint8),   # heatmap RGB upscaled (placeholder)
            "<div style='text-align:center;padding:20px;'>Upload or load an image first.</div>",
        )

    pixel_values = _prepare_image_tensor(pil_image, img_size=512)

    with torch.no_grad():
        gen_ids, gen_text, attentions = model.generate(
            pixel_values=pixel_values,
            max_new_tokens=int(max_new_tokens),
            output_attentions=True,
        )

    # Expect batch size 1
    if isinstance(gen_ids, torch.Tensor):
        gen_ids = gen_ids[0].tolist()
    gen_ids = _strip_trailing_special(gen_ids)

    words_gen, gen_word2tok_rel = _words_and_map_from_tokens_simple(gen_ids)

    display_choices = [(w, i) for i, w in enumerate(words_gen)]
    if not display_choices:
        # No generated tokens; still show original and blank heatmap/overlay
        blank_hm = np.zeros((32, 32), dtype=np.uint8)
        hm_rgb = _colorize_heatmap(blank_hm).resize(pil_image.size, resample=Image.NEAREST)
        overlay = _make_overlay(pil_image, hm_rgb, alpha=0.35)
        return (
            attentions, gen_ids, 1024, words_gen, gen_word2tok_rel,
            gr.update(choices=[], value=None),
            pil_image,          # original
            overlay,            # overlay
            np.array(hm_rgb),   # heatmap RGB
            "<div style='text-align:center;padding:20px;'>No generated tokens to visualize.</div>",
        )

    first_idx = 0
    hm_rgb_init, overlay_init, html_init = update_visualization(
        selected_gen_index=first_idx,
        attentions=attentions,
        gen_token_ids=gen_ids,
        layer=layer,
        head=head,
        mean_layers=mean_layers,
        mean_heads=mean_heads,
        words_gen=words_gen,
        gen_word2tok_rel=gen_word2tok_rel,
        pil_image=pil_image,
    )

    return (
        attentions,            # state_attentions
        gen_ids,               # state_gen_token_ids
        1024,                  # state_img_token_len (fixed)
        words_gen,             # state_words_gen
        gen_word2tok_rel,      # state_gen_word2tok_rel
        gr.update(choices=display_choices, value=first_idx),
        pil_image,             # original image view
        overlay_init,          # overlay (PIL)
        hm_rgb_init,           # heatmap RGB (np array or PIL)
        html_init,             # HTML words viz
    )

def update_visualization(
    selected_gen_index,
    attentions,
    gen_token_ids,
    layer,
    head,
    mean_layers,
    mean_heads,
    words_gen,
    gen_word2tok_rel,
    pil_image: Optional[Image.Image] = None,
):
    """
    Recompute visualization for the chosen GENERATED word:
      - Extract attention vector for that generation step.
      - Build 32×32 heatmap from first 1024 values (image tokens), colorize and upscale to original image size.
      - Create overlay (original + heatmap with alpha).
      - Build word HTML from the portion corresponding to generated tokens.
    For step t, keys cover: 1024 image tokens + (t+1) generated tokens so far.
    """
    if selected_gen_index is None or attentions is None or gen_word2tok_rel is None:
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div style='text-align:center;padding:20px;'>Generate first.</div>"

    gidx = int(selected_gen_index)
    if not (0 <= gidx < len(gen_word2tok_rel)):
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div style='text-align:center;padding:20px;'>Invalid selection.</div>"

    step_index = int(gen_word2tok_rel[gidx])  # last token of that word (relative to generation)
    if not attentions or step_index >= len(attentions):
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div style='text-align:center;padding:20px;'>No attention for this step.</div>"

    token_attn = get_attention_for_token_layer(
        attentions,
        token_index=step_index,
        layer_index=int(layer),
        head_index=int(head),
        mean_across_layers=bool(mean_layers),
        mean_across_heads=bool(mean_heads),
    )

    attn_vals = token_attn.detach().cpu().numpy()
    if attn_vals.ndim == 2:
        attn_vals = attn_vals[-1]  # (k,) from (q,k)

    # ---- Heatmap over 1024 image tokens (colorized and upscaled to original size) ----
    heatmap_u8 = _attention_to_heatmap_uint8(attn_1d=attn_vals, img_token_len=1024, side=32)
    hm_rgb_pil = _colorize_heatmap(heatmap_u8)

    # If original image not provided (should be), create a placeholder size
    if pil_image is None:
        pil_image = Image.new("RGB", (256, 256), "black")

    hm_rgb_pil_up = hm_rgb_pil.resize(pil_image.size, resample=Image.NEAREST)
    overlay_pil = _make_overlay(pil_image, hm_rgb_pil_up, alpha=0.35)

    # ---- Word-level viz over generated tokens only ----
    k_len = int(attn_vals.shape[0])
    observed_gen = max(0, min(step_index + 1, max(0, k_len - 1024)))
    total_gen = len(gen_token_ids)

    gen_vec = np.zeros(total_gen, dtype=float)
    if observed_gen > 0:
        # slice generated part of attention vector
        start = 1024
        end = min(1024 + observed_gen, k_len)
        gen_slice = attn_vals[start:end]
        gen_vec[: len(gen_slice)] = gen_slice

    selected_token_rel_idx = step_index

    html_words = generate_word_visualization_gen_only(
        words_gen=words_gen,
        word_ends_rel=gen_word2tok_rel,
        gen_attn_values=gen_vec,
        selected_token_rel_idx=selected_token_rel_idx,
    )

    # Return (heatmap RGB, overlay, html)
    return np.array(hm_rgb_pil_up), overlay_pil, html_words
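
# Key-axis layout example: if the selected word ends at generation step t = 5, the
# attention vector has k_len = 1024 + 6 entries; attn_vals[:1024] becomes the 32×32
# heatmap and attn_vals[1024:1030] fills gen_vec[:6], with later (not yet generated)
# positions left at zero.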

def toggle_slider(is_mean):
    return gr.update(interactive=not bool(is_mean))

# ========= Gradio UI =========
EXAMPLES_DIR = "examples"

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️→📝 Image-to-Text Attention Visualizer (three views + text)")
    gr.Markdown(
        "Upload an image or click **Load random sample**, generate text, then select a **generated word**. "
        "Above: original image, overlay (original + attention), and heatmap (colored). "
        "Below: word-level attention over generated text."
    )

    # States
    state_attentions = gr.State(None)        # tuple over generation steps
    state_gen_token_ids = gr.State(None)     # list[int]
    state_img_token_len = gr.State(1024)     # fixed
    state_words_gen = gr.State(None)         # list[str]
    state_gen_word2tok_rel = gr.State(None)  # list[int]
    state_last_image = gr.State(None)        # PIL image of last input

    L, H = model_heads_layers()

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1) Image")
            img_input = gr.Image(type="pil", label="Upload image", height=280)
            btn_load_sample = gr.Button("Load random sample from /examples", variant="secondary")
            sample_status = gr.Markdown("")

            gr.Markdown("### 2) Generation")
            slider_max_tokens = gr.Slider(5, 200, value=40, step=5, label="Max New Tokens")
            btn_generate = gr.Button("Generate", variant="primary")

            gr.Markdown("### 3) Attention")
            check_mean_layers = gr.Checkbox(True, label="Mean Across Layers")
            check_mean_heads = gr.Checkbox(True, label="Mean Across Heads")
            slider_layer = gr.Slider(0, max(0, L - 1), value=0, step=1, label="Layer", interactive=False)
            slider_head = gr.Slider(0, max(0, H - 1), value=0, step=1, label="Head", interactive=False)

        with gr.Column(scale=3):
            # Three views row
            with gr.Row():
                img_original_view = gr.Image(
                    value=None,
                    label="Original image",
                    image_mode="RGB",
                    height=256,
                )
                img_overlay_view = gr.Image(
                    value=None,
                    label="Overlay (image + attention)",
                    image_mode="RGB",
                    height=256,
                )
                heatmap_view = gr.Image(
                    value=None,
                    label="Heatmap (colored)",
                    image_mode="RGB",
                    height=256,
                )

            # Word selector & HTML viz below
            radio_word_selector = gr.Radio(
                [], label="Select Generated Word",
                info="Selector lists only generated words",
            )
            html_visualization = gr.HTML(
                "<div style='text-align:center;padding:20px;color:#888;border:1px dashed #888;border-radius:8px;'>"
                "Text attention visualization will appear here.</div>"
            )

    # Sample loader: always use `examples/`
    def _load_sample_from_examples():
        try:
            files = [f for f in os.listdir(EXAMPLES_DIR) if not f.startswith(".")]
            if not files:
                return gr.update(), "No files in /examples."
            fp = os.path.join(EXAMPLES_DIR, random.choice(files))
            pil_img = pil_from_path(fp)
            return gr.update(value=pil_img), f"Loaded sample: {os.path.basename(fp)}"
        except Exception as e:
            return gr.update(), f"Error loading sample: {e}"

    btn_load_sample.click(
        fn=_load_sample_from_examples,
        inputs=[],
        outputs=[img_input, sample_status],
    )

    # Generate
    def _run_and_store(pil_image, *args):
        out = run_generation(pil_image, *args)
        # store the original image for later updates
        return (*out, pil_image)

    btn_generate.click(
        fn=_run_and_store,
        inputs=[img_input, slider_max_tokens, slider_layer, slider_head, check_mean_layers, check_mean_heads],
        outputs=[
            state_attentions,
            state_gen_token_ids,
            state_img_token_len,
            state_words_gen,
            state_gen_word2tok_rel,
            radio_word_selector,
            img_original_view,    # original
            img_overlay_view,     # overlay
            heatmap_view,         # heatmap
            html_visualization,   # words HTML
            state_last_image,     # store original PIL
        ],
    )

    # Update viz on any control change
    def _update_wrapper(selected_gen_index, attn, gen_ids, lyr, hed, meanL, meanH, words, word2tok, last_img):
        hm_rgb, overlay, html = update_visualization(
            selected_gen_index,
            attn,
            gen_ids,
            lyr,
            hed,
            meanL,
            meanH,
            words,
            word2tok,
            pil_image=last_img,
        )
        return overlay, hm_rgb, html

    for control in [radio_word_selector, slider_layer, slider_head, check_mean_layers, check_mean_heads]:
        control.change(
            fn=_update_wrapper,
            inputs=[
                radio_word_selector,
                state_attentions,
                state_gen_token_ids,
                slider_layer,
                slider_head,
                check_mean_layers,
                check_mean_heads,
                state_words_gen,
                state_gen_word2tok_rel,
                state_last_image,
            ],
            outputs=[img_overlay_view, heatmap_view, html_visualization],
        )

    # Toggle slider interactivity
    check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer)
    check_mean_heads.change(toggle_slider, check_mean_heads, slider_head)

if __name__ == "__main__":
    print(f"Device: {device}")
    demo.launch(debug=True)