# app.py
"""
🖼️→📝 Image-to-Text Attention Visualizer (Custom Model)
- Loads your custom model via create_complete_model()
- Accepts an image, applies your transform, then calls:
model.generate(pixel_values=..., max_new_tokens=..., output_attentions=True)
- Selector lists ONLY generated words (no prompt tokens).
- Viewer (single row) shows:
(1) original image,
(2) original + colored attention heatmap overlay,
(3) heatmap alone (colored).
- Heatmap is built from the first 1024 image tokens (32×32), then upscaled to the image size.
- Text block below shows word-level attention over generated tokens (no return_offsets_mapping used).
- Fixes deprecations: Matplotlib colormap API & Pillow mode inference.
"""
import os
import re
import random
from typing import List, Tuple, Optional
import gradio as gr
import torch
import numpy as np
from PIL import Image
from safetensors.torch import load_model
# Optional: nicer colormap (Matplotlib >=3.7 API; no deprecation warnings)
try:
import matplotlib as mpl
_HAS_MPL = True
_COLORMAP = mpl.colormaps.get_cmap("magma")
except Exception:
_HAS_MPL = False
_COLORMAP = None
# ========= Your utilities & model =========
from utils.processing import image_transform, pil_from_path
from utils.complete_model import create_complete_model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
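# "eager" attention is requested so attention weights are actually materialized and can be
# returned with output_attentions; fused/SDPA implementations typically do not expose them.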
model = create_complete_model(device=DEVICE, attention_implementation="eager")
SAFETENSOR_PATH = "complete_model.safetensor"
try:
load_model(model, SAFETENSOR_PATH)
except Exception as e:
print(f"Error loading model: {e}, continuing with uninitialized weights.")
model.eval()
device = DEVICE
# --- Grab tokenizer from your model ---
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is None:
raise ValueError("Expected `model.tokenizer` to exist and be a HF-like tokenizer.")
# --- Fix PAD/EOS ambiguity (and resize embeddings if applicable) ---
needs_resize = False
pad_id = getattr(tokenizer, "pad_token_id", None)
eos_id = getattr(tokenizer, "eos_token_id", None)
if pad_id is None or (eos_id is not None and pad_id == eos_id):
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
needs_resize = True
# Try common resize hooks safely (only if your decoder actually uses tokenizer vocab)
if needs_resize:
resize_fns = [
getattr(getattr(model, "decoder", None), "resize_token_embeddings", None),
getattr(model, "resize_token_embeddings", None),
]
for fn in resize_fns:
if callable(fn):
try:
fn(len(tokenizer))
break
except Exception:
# If your model doesn't need resizing (separate vocab), it's fine.
pass
# ========= Regex for word-level splitting (words and punctuation) =========
WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")
# ========= Model metadata (for slider ranges) =========
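# Probe common config attributes (model.config, decoder.config, ...) for layer/head counts
# so the UI sliders can be sized; fall back to 12 layers / 12 heads if nothing is found.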
def model_heads_layers():
def _get(obj, *names, default=None):
for n in names:
if obj is None:
return default
if hasattr(obj, n):
try:
return int(getattr(obj, n))
except Exception:
return default
return default
cfg_candidates = [
getattr(model, "config", None),
getattr(getattr(model, "decoder", None), "config", None),
getattr(getattr(model, "lm_head", None), "config", None),
]
L = H = None
for cfg in cfg_candidates:
if L is None:
L = _get(cfg, "num_hidden_layers", "n_layer", default=None)
if H is None:
H = _get(cfg, "num_attention_heads", "n_head", default=None)
if L is None: L = 12
if H is None: H = 12
return max(1, L), max(1, H)
# ========= Attention utils =========
def get_attention_for_token_layer(
attentions,
token_index,
layer_index,
batch_index=0,
head_index=0,
mean_across_layers=True,
mean_across_heads=True,
):
"""
`attentions`:
tuple length = #generated tokens
attentions[t] -> tuple over layers; each layer tensor is (batch, heads, q, k)
"""
token_attention = attentions[token_index]
if mean_across_layers:
layer_attention = torch.stack(token_attention).mean(dim=0) # (batch, heads, q, k)
else:
layer_attention = token_attention[int(layer_index)] # (batch, heads, q, k)
batch_attention = layer_attention[int(batch_index)] # (heads, q, k)
if mean_across_heads:
head_attention = batch_attention.mean(dim=0) # (q, k)
else:
head_attention = batch_attention[int(head_index)] # (q, k)
return head_attention.squeeze(0) # q==1 -> (k,)
# ========= Tokens → words mapping (no offset_mapping needed) =========
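# Note: the mapping below re-encodes detokenized prefixes, which assumes the tokenizer
# round-trips its own detokenized text with stable prefix lengths; this holds for typical
# BPE/WordPiece tokenizers but is an approximation, not an exact character-offset alignment.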
def _words_and_map_from_tokens_simple(token_ids: List[int]) -> Tuple[List[str], List[int]]:
"""
Works with slow/fast tokenizers. No return_offsets_mapping.
Steps:
1) detok token_ids
2) regex-split words and get their char-end positions
3) for each word-end (we), encode detok[:we] w/ add_special_tokens=False
last token index = len(prefix_ids) - 1
"""
if not token_ids:
return [], []
toks = tokenizer.convert_ids_to_tokens(token_ids)
detok = tokenizer.convert_tokens_to_string(toks)
matches = list(re.finditer(WORD_RE, detok))
words = [m.group(0) for m in matches]
ends = [m.span()[1] for m in matches] # char end (exclusive)
word2tok: List[int] = []
for we in ends:
prefix_ids = tokenizer.encode(detok[:we], add_special_tokens=False)
if not prefix_ids:
word2tok.append(0)
continue
last_idx = len(prefix_ids) - 1
last_idx = max(0, min(last_idx, len(token_ids) - 1))
word2tok.append(last_idx)
return words, word2tok
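# Trim trailing special tokens (EOS/PAD) so they do not surface as selectable "words".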
def _strip_trailing_special(ids: List[int]) -> List[int]:
specials = set(getattr(tokenizer, "all_special_ids", []) or [])
j = len(ids)
while j > 0 and ids[j - 1] in specials:
j -= 1
return ids[:j]
# ========= Visualization (word-level for generated text) =========
def generate_word_visualization_gen_only(
words_gen: List[str],
word_ends_rel: List[int],
gen_attn_values: np.ndarray,
selected_token_rel_idx: int,
) -> str:
"""
words_gen: generated words only
word_ends_rel: last-token indices of each generated word (relative to generation)
gen_attn_values: length == len(gen_token_ids), attention over generated tokens only
(zeros for future tokens padded at the end)
"""
if not words_gen or gen_attn_values is None or len(gen_attn_values) == 0:
return (
"<div style='width:100%;'>"
" <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
" <div style='color:#ddd;'>No text attention values.</div>"
" </div>"
"</div>"
)
# compute word starts from ends (inclusive indexing)
starts = []
for i, end in enumerate(word_ends_rel):
if i == 0:
starts.append(0)
else:
starts.append(min(word_ends_rel[i - 1] + 1, end))
# sum attention per word
word_scores = []
T = len(gen_attn_values)
for i, end in enumerate(word_ends_rel):
start = starts[i]
if start > end:
start = end
s = max(0, min(start, T - 1))
e = max(0, min(end, T - 1))
if e < s:
s, e = e, s
word_scores.append(float(gen_attn_values[s:e + 1].sum()))
max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0)
# find selected word (contains selected token idx)
selected_word_idx = None
for i, end in enumerate(word_ends_rel):
if selected_token_rel_idx <= end:
selected_word_idx = i
break
if selected_word_idx is None and word_ends_rel:
selected_word_idx = len(word_ends_rel) - 1
spans = []
for i, w in enumerate(words_gen):
alpha = min(1.0, word_scores[i] / max_attn) if max_attn > 0 else 0.0
bg = f"rgba(66,133,244,{alpha:.3f})"
border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent"
spans.append(
f"<span style='display:inline-block;background:{bg};border:{border};"
f"border-radius:6px;padding:2px 6px;margin:2px 4px 4px 0;color:#fff;'>"
f"{w}</span>"
)
return (
"<div style='width:100%;'>"
" <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
" <div style='white-space:normal;line-height:1.8;'>"
f" {''.join(spans)}"
" </div>"
" </div>"
"</div>"
)
# ========= Heatmap helpers for 1024 image tokens =========
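# Assumption: the vision encoder contributes exactly 1024 patch tokens, occupying the first
# key positions and laid out row-major as a 32x32 grid.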
def _attention_to_heatmap_uint8(attn_1d: np.ndarray, img_token_len: int = 1024, side: int = 32) -> np.ndarray:
"""
attn_1d: (k,) attention over keys for a given generation step; first 1024 are image tokens.
Returns a (32, 32) uint8 grayscale array.
"""
# take first 1024 (image tokens); pad/truncate as needed
if attn_1d.shape[0] < img_token_len:
img_part = np.zeros(img_token_len, dtype=float)
img_part[: attn_1d.shape[0]] = attn_1d
else:
img_part = attn_1d[:img_token_len]
# normalize to [0,1]
mn, mx = float(img_part.min()), float(img_part.max())
denom = (mx - mn) if (mx - mn) > 1e-12 else 1.0
norm = (img_part - mn) / denom
# return uint8 (0–255)
return (norm.reshape(side, side) * 255.0).astype(np.uint8)
def _colorize_heatmap(heatmap_u8: np.ndarray) -> Image.Image:
"""
Convert (H,W) uint8 grayscale to RGB heatmap using matplotlib (if available) or a simple fallback.
"""
if _HAS_MPL and _COLORMAP is not None:
colored = (_COLORMAP(heatmap_u8.astype(np.float32) / 255.0)[:, :, :3] * 255.0).astype(np.uint8)
return Image.fromarray(colored) # Pillow infers RGB
else:
# Fallback colormap: linear red channel, sqrt-boosted green, no blue (dark-to-yellow ramp)
g = heatmap_u8.astype(np.float32) / 255.0
r = (g * 255.0).clip(0, 255).astype(np.uint8)
g2 = (np.sqrt(g) * 255.0).clip(0, 255).astype(np.uint8)
b = np.zeros_like(r, dtype=np.uint8)
rgb = np.stack([r, g2, b], axis=-1)
return Image.fromarray(rgb) # Pillow infers RGB
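# _make_overlay only rescales the heatmap (bilinear) when its size does not already match the
# original image; callers normally upscale with NEAREST first so the token grid stays blocky.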
def _resize_like(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
return img.resize(target_size, resample=Image.BILINEAR)
def _make_overlay(orig: Image.Image, heatmap_rgb: Image.Image, alpha: float = 0.35) -> Image.Image:
"""
Blend heatmap over original. alpha in [0,1].
"""
if heatmap_rgb.size != orig.size:
heatmap_rgb = _resize_like(heatmap_rgb, orig.size)
base = orig.convert("RGBA")
overlay = heatmap_rgb.convert("RGBA")
# set global alpha
r, g, b = overlay.split()[:3]
a = Image.new("L", overlay.size, int(alpha * 255))
overlay = Image.merge("RGBA", (r, g, b, a))
return Image.alpha_composite(base, overlay).convert("RGB")
# ========= Core (image → generate) =========
def _prepare_image_tensor(pil_img, img_size=512):
tfm = image_transform(img_size=img_size)
tens = tfm(pil_img).unsqueeze(0).to(device, non_blocking=True) # [1,3,H,W]
return tens
def run_generation(pil_image, max_new_tokens, layer, head, mean_layers, mean_heads):
"""
1) Transform image
2) model.generate(pixel_values=..., max_new_tokens=..., output_attentions=True)
expected to return (gen_ids, gen_text, attentions)
3) Build selector over generated words only
4) Initial visualization -> (orig, overlay, heatmap, word HTML)
"""
if pil_image is None:
# Return placeholders
blank = Image.new("RGB", (256, 256), "black")
return (
None, None, 1024, None, None,
gr.update(choices=[], value=None),
blank, # original
blank, # overlay
np.zeros((256, 256, 3), dtype=np.uint8), # heatmap RGB upscaled (placeholder)
"<div style='text-align:center;padding:20px;'>Upload or load an image first.</div>",
)
pixel_values = _prepare_image_tensor(pil_image, img_size=512)
with torch.no_grad():
gen_ids, gen_text, attentions = model.generate(
pixel_values=pixel_values,
max_new_tokens=int(max_new_tokens),
output_attentions=True
)
# Expect batch size 1
if isinstance(gen_ids, torch.Tensor):
gen_ids = gen_ids[0].tolist()
gen_ids = _strip_trailing_special(gen_ids)
words_gen, gen_word2tok_rel = _words_and_map_from_tokens_simple(gen_ids)
display_choices = [(w, i) for i, w in enumerate(words_gen)]
if not display_choices:
# No generated tokens; still show original and blank heatmap/overlay
blank_hm = np.zeros((32, 32), dtype=np.uint8)
hm_rgb = _colorize_heatmap(blank_hm).resize(pil_image.size, resample=Image.NEAREST)
overlay = _make_overlay(pil_image, hm_rgb, alpha=0.35)
return (
attentions, gen_ids, 1024, words_gen, gen_word2tok_rel,
gr.update(choices=[], value=None),
pil_image, # original
overlay, # overlay
np.array(hm_rgb), # heatmap RGB
"<div style='text-align:center;padding:20px;'>No generated tokens to visualize.</div>",
)
first_idx = 0
hm_rgb_init, overlay_init, html_init = update_visualization(
selected_gen_index=first_idx,
attentions=attentions,
gen_token_ids=gen_ids,
layer=layer,
head=head,
mean_layers=mean_layers,
mean_heads=mean_heads,
words_gen=words_gen,
gen_word2tok_rel=gen_word2tok_rel,
pil_image=pil_image,
)
return (
attentions, # state_attentions
gen_ids, # state_gen_token_ids
1024, # state_img_token_len (fixed)
words_gen, # state_words_gen
gen_word2tok_rel, # state_gen_word2tok_rel
gr.update(choices=display_choices, value=first_idx),
pil_image, # original image view
overlay_init, # overlay (PIL)
hm_rgb_init, # heatmap RGB (np array or PIL)
html_init, # HTML words viz
)
def update_visualization(
selected_gen_index,
attentions,
gen_token_ids,
layer,
head,
mean_layers,
mean_heads,
words_gen,
gen_word2tok_rel,
pil_image: Optional[Image.Image] = None,
):
"""
Recompute visualization for the chosen GENERATED word:
- Extract attention vector for that generation step.
- Build 32×32 heatmap from first 1024 values (image tokens), colorize and upscale to original image size.
- Create overlay (original + heatmap with alpha).
- Build word HTML from the portion corresponding to generated tokens.
For step t, keys cover: 1024 image tokens + (t+1) generated tokens so far.
"""
if selected_gen_index is None or attentions is None or gen_word2tok_rel is None:
blank = np.zeros((256, 256, 3), dtype=np.uint8)
return Image.fromarray(blank), Image.fromarray(blank), "<div style='text-align:center;padding:20px;'>Generate first.</div>"
gidx = int(selected_gen_index)
if not (0 <= gidx < len(gen_word2tok_rel)):
blank = np.zeros((256, 256, 3), dtype=np.uint8)
return Image.fromarray(blank), Image.fromarray(blank), "<div style='text-align:center;padding:20px;'>Invalid selection.</div>"
step_index = int(gen_word2tok_rel[gidx]) # last token of that word (relative to generation)
if not attentions or step_index >= len(attentions):
blank = np.zeros((256, 256, 3), dtype=np.uint8)
return Image.fromarray(blank), Image.fromarray(blank), "<div style='text-align:center;padding:20px;'>No attention for this step.</div>"
token_attn = get_attention_for_token_layer(
attentions,
token_index=step_index,
layer_index=int(layer),
head_index=int(head),
mean_across_layers=bool(mean_layers),
mean_across_heads=bool(mean_heads),
)
attn_vals = token_attn.detach().cpu().numpy()
if attn_vals.ndim == 2:
attn_vals = attn_vals[-1] # (k,) from (q,k)
# ---- Heatmap over 1024 image tokens (colorized and upscaled to original size) ----
heatmap_u8 = _attention_to_heatmap_uint8(attn_1d=attn_vals, img_token_len=1024, side=32)
hm_rgb_pil = _colorize_heatmap(heatmap_u8)
# If the original image was not provided (it normally is), fall back to a black placeholder
if pil_image is None:
pil_image = Image.new("RGB", (256, 256), "black")
hm_rgb_pil_up = hm_rgb_pil.resize(pil_image.size, resample=Image.NEAREST)
overlay_pil = _make_overlay(pil_image, hm_rgb_pil_up, alpha=0.35)
# ---- Word-level viz over generated tokens only ----
k_len = int(attn_vals.shape[0])
observed_gen = max(0, min(step_index + 1, max(0, k_len - 1024)))
total_gen = len(gen_token_ids)
gen_vec = np.zeros(total_gen, dtype=float)
if observed_gen > 0:
# slice generated part of attention vector
start = 1024
end = min(1024 + observed_gen, k_len)
gen_slice = attn_vals[start:end]
gen_vec[: len(gen_slice)] = gen_slice
selected_token_rel_idx = step_index
html_words = generate_word_visualization_gen_only(
words_gen=words_gen,
word_ends_rel=gen_word2tok_rel,
gen_attn_values=gen_vec,
selected_token_rel_idx=selected_token_rel_idx,
)
# Return (heatmap RGB, overlay, html)
return np.array(hm_rgb_pil_up), overlay_pil, html_words
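# The layer/head sliders only matter when not averaging, so they are disabled while the
# corresponding "Mean Across ..." checkbox is ticked (wired up at the bottom of the UI).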
def toggle_slider(is_mean):
return gr.update(interactive=not bool(is_mean))
# ========= Gradio UI =========
EXAMPLES_DIR = "examples"
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🖼️→📝 Image-to-Text Attention Visualizer (three views + text)")
gr.Markdown(
"Upload an image or click **Load random sample**, generate text, then select a **generated word**. "
"Above: original image, overlay (original + attention), and heatmap (colored). "
"Below: word-level attention over generated text."
)
# States
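# Generation artifacts are kept in gr.State so changing the selected word, layer, or head
# re-renders the three views and the text HTML without calling model.generate again.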
state_attentions = gr.State(None) # tuple over generation steps
state_gen_token_ids = gr.State(None) # list[int]
state_img_token_len = gr.State(1024) # fixed
state_words_gen = gr.State(None) # list[str]
state_gen_word2tok_rel = gr.State(None) # list[int]
state_last_image = gr.State(None) # PIL image of last input
L, H = model_heads_layers()
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### 1) Image")
img_input = gr.Image(type="pil", label="Upload image", height=280)
btn_load_sample = gr.Button("Load random sample from /examples", variant="secondary")
sample_status = gr.Markdown("")
gr.Markdown("### 2) Generation")
slider_max_tokens = gr.Slider(5, 200, value=40, step=5, label="Max New Tokens")
btn_generate = gr.Button("Generate", variant="primary")
gr.Markdown("### 3) Attention")
check_mean_layers = gr.Checkbox(True, label="Mean Across Layers")
check_mean_heads = gr.Checkbox(True, label="Mean Across Heads")
slider_layer = gr.Slider(0, max(0, L - 1), value=0, step=1, label="Layer", interactive=False)
slider_head = gr.Slider(0, max(0, H - 1), value=0, step=1, label="Head", interactive=False)
with gr.Column(scale=3):
# Three views row
with gr.Row():
img_original_view = gr.Image(
value=None,
label="Original image",
image_mode="RGB",
height=256
)
img_overlay_view = gr.Image(
value=None,
label="Overlay (image + attention)",
image_mode="RGB",
height=256
)
heatmap_view = gr.Image(
value=None,
label="Heatmap (colored)",
image_mode="RGB",
height=256
)
# Word selector & HTML viz below
radio_word_selector = gr.Radio(
[], label="Select Generated Word",
info="Selector lists only generated words"
)
html_visualization = gr.HTML(
"<div style='text-align:center;padding:20px;color:#888;border:1px dashed #888;border-radius:8px;'>"
"Text attention visualization will appear here.</div>"
)
# Sample loader: always use `examples/`
def _load_sample_from_examples():
try:
files = [f for f in os.listdir(EXAMPLES_DIR) if not f.startswith(".")]
if not files:
return gr.update(), "No files in /examples."
fp = os.path.join(EXAMPLES_DIR, random.choice(files))
pil_img = pil_from_path(fp)
return gr.update(value=pil_img), f"Loaded sample: {os.path.basename(fp)}"
except Exception as e:
return gr.update(), f"Error loading sample: {e}"
btn_load_sample.click(
fn=_load_sample_from_examples,
inputs=[],
outputs=[img_input, sample_status]
)
# Generate
def _run_and_store(pil_image, *args):
out = run_generation(pil_image, *args)
# store the original image for later updates
return (*out, pil_image)
btn_generate.click(
fn=_run_and_store,
inputs=[img_input, slider_max_tokens, slider_layer, slider_head, check_mean_layers, check_mean_heads],
outputs=[
state_attentions,
state_gen_token_ids,
state_img_token_len,
state_words_gen,
state_gen_word2tok_rel,
radio_word_selector,
img_original_view, # original
img_overlay_view, # overlay
heatmap_view, # heatmap
html_visualization, # words HTML
state_last_image, # store original PIL
],
)
# Update viz on any control change
def _update_wrapper(selected_gen_index, attn, gen_ids, lyr, hed, meanL, meanH, words, word2tok, last_img):
hm_rgb, overlay, html = update_visualization(
selected_gen_index,
attn,
gen_ids,
lyr,
hed,
meanL,
meanH,
words,
word2tok,
pil_image=last_img
)
return overlay, hm_rgb, html
for control in [radio_word_selector, slider_layer, slider_head, check_mean_layers, check_mean_heads]:
control.change(
fn=_update_wrapper,
inputs=[
radio_word_selector,
state_attentions,
state_gen_token_ids,
slider_layer,
slider_head,
check_mean_layers,
check_mean_heads,
state_words_gen,
state_gen_word2tok_rel,
state_last_image,
],
outputs=[img_overlay_view, heatmap_view, html_visualization],
)
# Toggle slider interactivity
check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer)
check_mean_heads.change(toggle_slider, check_mean_heads, slider_head)
if __name__ == "__main__":
print(f"Device: {device}")
demo.launch(debug=True)