| | |
| | |
| | |
| | |
| |
|
| | import cv2 |
| | import numpy as np |
| | from PIL import Image, ImageDraw |
| | import torch |
| | import matplotlib.pyplot as plt |
| | from skimage import filters |
| | from IPython.display import display |
| |
|
def gaussian_blur(heatmap, kernel_size=7):
    """
    Smooth a heatmap with a square Gaussian kernel using OpenCV.

    :param heatmap: torch.Tensor to blur; it is detached and copied to the CPU
        before being handed to ``cv2.GaussianBlur``
    :param kernel_size: Side length of the square kernel (OpenCV requires a
        positive odd value)
    :return: The blurred heatmap as a new CPU torch.Tensor
    """
    # detach() first: Tensor.numpy() refuses tensors that still require grad
    # (e.g. attention maps captured outside torch.no_grad()).
    array = heatmap.detach().cpu().numpy()
    # A sigma of 0 lets OpenCV derive the standard deviation from the kernel size.
    blurred = cv2.GaussianBlur(array, (kernel_size, kernel_size), 0)
    return torch.tensor(blurred)
| |
|
def show_cam_on_image(img, mask, use_rgb=False, colormap=None):
    """
    Blend a colour-mapped mask with an image and rescale the result to peak at 1.

    :param img: Image as an array-like of floats (the caller in this file
        passes a min-max normalised array); it is added to the coloured mask
        element-wise, so the shapes must broadcast
    :param mask: Array-like with values in [0, 1]; it is scaled by 255 and
        cast to uint8 before the colormap is applied
    :param use_rgb: When True the coloured mask is converted from OpenCV's BGR
        channel order to RGB before blending. Set this when ``img`` is RGB.
        Defaults to False, which keeps the BGR mask
    :param colormap: OpenCV colormap constant; ``None`` selects
        ``cv2.COLORMAP_JET``
    :return: float32 numpy array holding the blended image
    """
    # Resolved here rather than as a default argument so the signature does
    # not touch cv2 at definition time.
    if colormap is None:
        colormap = cv2.COLORMAP_JET

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        # applyColorMap always produces BGR; match an RGB image on request.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = heatmap + np.float32(img)
    peak = np.max(cam)
    # Skip the rescale for an all-zero blend, which would otherwise be 0 / 0.
    if peak > 0:
        cam = cam / peak
    return cam
| |
|
def show_image_and_heatmap(heatmap: torch.Tensor, image: Image.Image, relevnace_res: int = 256, interpolation: str = 'bilinear', gassussian_kernel_size: int = 3):
    """
    Overlay a colour-mapped heatmap on an image and return the blend.

    :param heatmap: Relevance map holding S * S values, where
        S = heatmap.shape[-1]; it is reshaped to (1, 1, S, S) and upsampled
    :param image: PIL image the heatmap is drawn on
    :param relevnace_res: Side length, in pixels, of the square output
    :param interpolation: Mode passed to ``torch.nn.functional.interpolate``
    :param gassussian_kernel_size: Currently unused; kept so callers that pass
        it keep working
    :return: RGB PIL image of size (relevnace_res, relevnace_res)
    """
    # Force three channels so grayscale / RGBA / palette inputs can be added
    # to the three-channel colour map instead of failing to broadcast.
    image = image.convert("RGB").resize((relevnace_res, relevnace_res))
    image = np.array(image)
    image_range = image.max() - image.min()
    if image_range > 0:
        image = (image - image.min()) / image_range
    else:
        # A flat image would divide by zero and fill the overlay with NaNs.
        image = np.zeros(image.shape, dtype=np.float64)
    # The colour map produced inside show_cam_on_image is BGR. Blend in BGR as
    # well so that the single channel swap at the end brings BOTH layers back
    # to RGB; blending RGB pixels here leaves the photo with red and blue
    # exchanged in the final picture.
    image = image[..., ::-1]

    heatmap = heatmap.detach()
    if heatmap.dtype not in (torch.float32, torch.float64):
        # Half / bfloat16 / integer maps are not reliably supported by CPU
        # interpolation or by the NumPy conversion further down.
        heatmap = heatmap.float()
    side = heatmap.shape[-1]
    heatmap = heatmap.reshape(1, 1, side, side)
    heatmap = torch.nn.functional.interpolate(heatmap, size=relevnace_res, mode=interpolation)

    low, high = heatmap.min(), heatmap.max()
    if high > low:
        heatmap = (heatmap - low) / (high - low)
    else:
        # Constant (or NaN) map: avoid 0 / 0 and show a uniform overlay.
        heatmap = torch.zeros_like(heatmap)
    heatmap = heatmap.reshape(relevnace_res, relevnace_res).cpu()

    vis = show_cam_on_image(image, heatmap.numpy())
    vis = np.uint8(255 * vis)
    # Swap channels once: the whole blend is BGR at this point, PIL wants RGB.
    vis = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)

    # The array already has the requested resolution, so no resize is needed.
    return Image.fromarray(vis)
| |
|
def show_only_heatmap(heatmap: torch.Tensor, relevnace_res: int = 256, interpolation: str = 'bilinear', gassussian_kernel_size: int = 3):
    """
    Render a heatmap on its own as a grayscale picture.

    :param heatmap: Relevance map holding S * S values, where
        S = heatmap.shape[-1]; it is reshaped to (1, 1, S, S) and upsampled
    :param relevnace_res: Side length, in pixels, of the square output
    :param interpolation: Mode passed to ``torch.nn.functional.interpolate``
    :param gassussian_kernel_size: Currently unused; kept so callers that pass
        it keep working
    :return: Three-channel PIL image of size (relevnace_res, relevnace_res)
        whose channels all carry the same gray value
    """
    heatmap = heatmap.detach()
    if heatmap.dtype not in (torch.float32, torch.float64):
        # Half / bfloat16 / integer maps are not reliably supported by CPU
        # interpolation or by the NumPy conversion further down.
        heatmap = heatmap.float()
    side = heatmap.shape[-1]
    heatmap = heatmap.reshape(1, 1, side, side)
    heatmap = torch.nn.functional.interpolate(heatmap, size=relevnace_res, mode=interpolation)

    low, high = heatmap.min(), heatmap.max()
    if high > low:
        heatmap = (heatmap - low) / (high - low)
    else:
        # Constant (or NaN) map: avoid 0 / 0 and render a uniform black image.
        heatmap = torch.zeros_like(heatmap)
    heatmap = heatmap.reshape(relevnace_res, relevnace_res).cpu()

    gray = np.uint8(255 * heatmap.numpy())
    # Replicate the single channel three times so PIL builds a colour image.
    vis = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)

    # The array already has the requested resolution, so no resize is needed.
    return Image.fromarray(vis)
| |
|
def visualize_tokens_attentions(attention, tokens, image, heatmap_interpolation="nearest", show_on_image=True):
    """
    Build a matplotlib figure with one attention visualisation per token.

    :param attention: Indexable collection of attention maps exposing
        ``shape[0]``; ``attention[j]`` is rendered for the j-th token
    :param tokens: Iterable of token labels used as subplot titles
    :param image: PIL image the maps are overlaid on when ``show_on_image``
        is True
    :param heatmap_interpolation: Interpolation mode forwarded to the heatmap
        renderers
    :param show_on_image: True to overlay each map on ``image``, False to
        render the map alone
    :return: The matplotlib figure, laid out four panels per row
    """
    token_vis = []
    # zip() stops at the shorter side, so tokens without a matching attention
    # map are skipped.
    for j, token in zip(range(attention.shape[0]), tokens):
        if show_on_image:
            vis = show_image_and_heatmap(attention[j], image, relevnace_res=512, interpolation=heatmap_interpolation)
        else:
            vis = show_only_heatmap(attention[j], relevnace_res=512, interpolation=heatmap_interpolation)
        token_vis.append((token, vis))

    n_cols = 4
    # Keep at least one row: plt.subplots rejects a zero-row grid, which used
    # to make the function fail when there was nothing to plot.
    n_rows = max(1, (len(token_vis) + n_cols - 1) // n_cols)
    # squeeze=False always returns a 2-D axes array, so one flat walk handles
    # single-row and multi-row grids alike.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5), squeeze=False)
    axes = list(axs.flat)

    for ax, (token, vis) in zip(axes, token_vis):
        ax.imshow(vis)
        ax.set_title(token)
        ax.axis("off")

    # Hide the panels of the last row that received no token.
    for ax in axes[len(token_vis):]:
        ax.axis('off')

    plt.tight_layout()

    return fig
| |
|
def show_images(images, titles=None, size=1024, max_row_length=5, figsize=None, col_height=10, save_path=None):
    """
    Display one or more PIL images with matplotlib, optionally saving the figure.

    :param images: A single PIL image or a sequence of PIL images; an empty
        list or tuple is a no-op
    :param titles: Optional titles, one per image. For a single image a bare
        string is accepted as well
    :param size: Every image is resized to (size, size) before display
    :param max_row_length: Maximum number of images per grid row
    :param figsize: Figure size of the grid; defaults to
        (n_cols * col_height, n_rows * col_height). Not used when a single
        image is shown
    :param col_height: Size of one grid cell, used for the default ``figsize``
    :param save_path: When truthy, the figure is saved there (150 dpi, tight
        bounding box) before it is shown
    :return: None
    """
    # Nothing to draw. Without this guard the grid maths below divides by
    # zero columns.
    if isinstance(images, (list, tuple)) and len(images) == 0:
        return

    if isinstance(images, Image.Image):
        images = [images]

    if len(images) == 1:
        plt.imshow(images[0].resize((size, size)))
        plt.axis('off')

        if titles is not None:
            # Indexing a bare string would keep only its first character.
            plt.title(titles if isinstance(titles, str) else titles[0])
    else:
        images = [img.resize((size, size)) for img in images]

        if titles is not None:
            assert len(images) == len(titles), "Number of titles should match the number of images"

        n_images = len(images)
        n_cols = min(n_images, max_row_length)
        n_rows = (n_images + n_cols - 1) // n_cols

        if figsize is None:
            figsize = (n_cols * col_height, n_rows * col_height)

        # squeeze=False always returns a 2-D axes array, whatever the grid shape.
        _, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
        axs = axs.flatten()

        for i, img in enumerate(images):
            axs[i].imshow(img)
            if titles is not None:
                axs[i].set_title(titles[i])
            axs[i].axis("off")

        # Hide the cells of the last row that hold no image.
        for ax in axs[n_images:]:
            ax.axis("off")

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=150)

    plt.show()
| |
|
def show_tensors(tensors, titles=None, size=None, max_row_length=5):
    """
    Display one or more 2-D tensors as images with matplotlib.

    :param tensors: Sequence of tensors, each shown with ``imshow``; a bare
        2-D tensor is treated as a single image and an empty sequence is a
        no-op
    :param titles: Optional titles, one per tensor. For a single tensor a bare
        string is accepted as well
    :param size: When given, every tensor is bilinearly resized to
        (size, size) before display
    :param max_row_length: Maximum number of tensors per grid row
    :return: None
    """
    # Accept a lone heatmap the same way show_images accepts a lone PIL image.
    if isinstance(tensors, torch.Tensor) and tensors.dim() == 2:
        tensors = [tensors]

    # Nothing to draw. Without this guard the grid maths below divides by
    # zero columns.
    if len(tensors) == 0:
        return

    if size is not None:
        tensors = [torch.nn.functional.interpolate(t.unsqueeze(0).unsqueeze(0), size=(size, size), mode='bilinear').squeeze() for t in tensors]

    # detach() first: Tensor.numpy() refuses tensors that still require grad.
    arrays = [t.detach().cpu().numpy() for t in tensors]

    if len(arrays) == 1:
        plt.imshow(arrays[0])
        plt.axis('off')

        if titles is not None:
            # Indexing a bare string would keep only its first character.
            plt.title(titles if isinstance(titles, str) else titles[0])
    else:
        if titles is not None:
            assert len(tensors) == len(titles), "Number of titles should match the number of images"

        n_tensors = len(arrays)
        n_cols = min(n_tensors, max_row_length)
        n_rows = (n_tensors + n_cols - 1) // n_cols

        # squeeze=False always returns a 2-D axes array, whatever the grid shape.
        _, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 10, n_rows * 10), squeeze=False)
        axs = axs.flatten()

        for i, array in enumerate(arrays):
            axs[i].imshow(array)
            if titles is not None:
                axs[i].set_title(titles[i])
            axs[i].axis("off")

        # Hide the cells of the last row that hold no tensor.
        for ax in axs[n_tensors:]:
            ax.axis("off")

    plt.show()
| |
|
def draw_bboxes_on_image(image, bboxes, color="red", thickness=2):
    """
    Outline bounding boxes on a copy of an image.

    :param image: Image to annotate; it is copied, so the input is not modified
    :param bboxes: Iterable of boxes, each forwarded unchanged to
        ``ImageDraw.rectangle``
    :param color: Outline colour of the rectangles
    :param thickness: Outline width passed to ``ImageDraw.rectangle``
    :return: The annotated copy
    """
    annotated = image.copy()
    canvas = ImageDraw.Draw(annotated)
    for box in bboxes:
        canvas.rectangle(box, outline=color, width=thickness)
    return annotated
| |
|
def draw_points_on_pil_image(pil_image, point_coords, point_color="red", radius=5):
    """
    Mark points as filled circles on a copy of a PIL image.

    :param pil_image: PIL image to annotate; it is copied, so the input is
        not modified
    :param point_coords: Iterable of (x, y) pairs giving the circle centres
    :param point_color: Fill and outline colour of the circles (default 'red')
    :param radius: Radius of each circle
    :return: The annotated copy of the image
    """
    annotated = pil_image.copy()
    canvas = ImageDraw.Draw(annotated)

    for cx, cy in point_coords:
        # The ellipse is described by the corners of its bounding square.
        bounding_box = [(cx - radius, cy - radius), (cx + radius, cy + radius)]
        canvas.ellipse(bounding_box, fill=point_color, outline=point_color)

    return annotated