Spaces:
Runtime error
Runtime error
| import spaces | |
| import gradio as gr | |
| import torch | |
| from omegaconf import OmegaConf | |
| from PIL import Image | |
| from diffusers import StableDiffusionInpaintPipeline | |
| from model.clip_away import CLIPAway | |
| import cv2 | |
| import numpy as np | |
| import argparse | |
def _load_models():
    """Build the inference config, the SD inpainting pipeline and the CLIPAway wrapper.

    Reads ``config/inference_config.yaml`` (resolved against the current
    working directory) and loads the ``botp/stable-diffusion-v1-5-inpainting``
    pipeline in float32. It then wraps that pipeline in ``CLIPAway`` with the
    checkpoint paths taken from the config, ``device="cuda"`` and
    ``num_tokens=4``.

    Returns:
        A ``(config, pipeline, clipaway)`` tuple.
    """
    cfg = OmegaConf.load("config/inference_config.yaml")
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "botp/stable-diffusion-v1-5-inpainting",
        torch_dtype=torch.float32,
    )
    remover = CLIPAway(
        sd_pipe=pipe,
        image_encoder_path=cfg.image_encoder_path,
        ip_ckpt=cfg.ip_adapter_ckpt_path,
        alpha_clip_path=cfg.alpha_clip_ckpt_pth,
        config=cfg,
        alpha_clip_id=cfg.alpha_clip_id,
        device="cuda",
        num_tokens=4,
    )
    return cfg, pipe, remover


# Loaded once at import time; every request handled by the UI reuses these objects.
config, sd_pipeline, clipaway = _load_models()
def dilate_mask(mask, kernel_size=5, iterations=5):
    """Return a dilated 512x512 single-channel copy of ``mask``.

    The mask is converted to ``"L"`` mode and resized with nearest-neighbour
    sampling, so the resize does not blend in new grey levels. The result is
    then grown with ``cv2.dilate`` using a square all-ones structuring element.
    Dilation takes a local maximum, so bright (white) regions expand.

    Args:
        mask: PIL image holding the mask. Per the UI instructions, white marks
            the object to remove; this function does not check that.
        kernel_size: side length of the square structuring element.
        iterations: how many times the dilation is applied.

    Returns:
        A new PIL image built from the dilated uint8 array.
    """
    square_kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
    grayscale = mask.convert("L")
    resized = grayscale.resize((512, 512), Image.NEAREST)
    dilated = cv2.dilate(np.array(resized), square_kernel, iterations=iterations)
    return Image.fromarray(dilated)
# On ZeroGPU Spaces, CUDA work (the ``.to("cuda")`` and ``clipaway.generate``
# calls below) must happen inside a ``spaces.GPU``-decorated function. Without
# the decorator the ``spaces`` import is unused and GPU calls fail at request time.
@spaces.GPU
def remove_obj(image, uploaded_mask, seed):
    """Remove the masked object from ``image`` using CLIPAway.

    Args:
        image: input PIL image (the UI requests RGB); resized to 512x512 with
            Lanczos resampling.
        uploaded_mask: PIL mask image; dilated and resized to 512x512 by
            ``dilate_mask``.
        seed: seed for the initial latent noise; coerced with ``int``.

    Returns:
        The first image returned by ``clipaway.generate``.

    Raises:
        gr.Error: if the image, mask or seed is missing. Gradio passes ``None``
            for empty components, and this shows a readable message instead of
            an AttributeError/TypeError traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if uploaded_mask is None:
        raise gr.Error("Please upload a mask (white = object to remove).")
    if seed is None:
        raise gr.Error("Please provide a seed.")

    image_pil = image.resize((512, 512), Image.LANCZOS)
    mask = dilate_mask(uploaded_mask)
    seed = int(seed)
    # The noise is drawn from a CPU generator seeded explicitly, then moved to
    # the GPU. (1, 4, 64, 64) is presumably the SD latent shape for a 512x512
    # image (8x VAE downsampling) -- TODO confirm against CLIPAway.
    latents = torch.randn(
        (1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)
    ).to("cuda")
    final_image = clipaway.generate(
        prompt=[""], scale=1, seed=seed,
        pil_image=[image_pil], alpha=[mask], strength=1, latents=latents,
    )[0]
    return final_image
# Rows for the gr.Examples widget, in the order of its inputs:
# (image path, mask path, seed). Paths are relative to the working directory.
_EXAMPLE_ROWS = (
    ("gradio_examples/images/1.jpg", "gradio_examples/masks/1.png", 42),
    ("gradio_examples/images/2.jpg", "gradio_examples/masks/2.png", 42),
    ("gradio_examples/images/3.jpg", "gradio_examples/masks/3.png", 464),
)
examples = [list(row) for row in _EXAMPLE_ROWS]
# Gradio UI: title, links, usage notes, an input column (image, mask, seed),
# an output column, and clickable examples. Launched at import time.
with gr.Blocks(theme="gradio/monochrome") as demo:
    gr.Markdown("<h1 style='text-align:center'>CLIPAway: Harmonizing Focused Embeddings for Removing Objects via Diffusion Models</h1>")
    gr.Markdown("""
<div style='display:flex; justify-content:center; align-items:center;'>
<a href='https://arxiv.org/abs/2406.09368' style="margin-right:10px;">Paper</a>
<a href='https://yigitekin.github.io/CLIPAway/' style="margin:10px;">Project Website</a>
<a href='https://github.com/YigitEkin/CLIPAway' style="margin-left:10px;">GitHub</a>
</div>
""")
    gr.Markdown("""
This application allows you to remove objects from images using the CLIPAway method with diffusion models.
To use this tool:
1. Upload an image. (NOTE: We expect a 512x512 image, if you upload a different size, it will be resized to 512x512 which can affect the results.)
2. Upload a pre-defined mask if you have one. (If you don't have a mask, and want to sketch one,
we have provided a gradio demo in our github repository. <br/> Unfortunately, we cannot provide it here due to the compatibility issues with zerogpu.)
3. Set the seed for reproducibility (default is 42).
4. Click 'Remove Object' to process the image.
5. The result will be displayed on the right side.
Note: The mask should be a binary image where the object to be removed is white and the background is black.
""")
    with gr.Row():
        with gr.Column():
            # Upload only: this Space offers no mask sketching (see the
            # instructions above), so the label must not mention it.
            image_input = gr.Image(label="Upload Image", type="pil", image_mode="RGB")
            uploaded_mask = gr.Image(label="Upload Mask", type="pil", image_mode="L")
            # precision=0 makes Gradio round and deliver the seed as an int.
            seed_input = gr.Number(value=42, label="Seed", precision=0)
            process_button = gr.Button("Remove Object")
        with gr.Column():
            result_image = gr.Image(label="Result")
    process_button.click(
        fn=remove_obj,
        inputs=[image_input, uploaded_mask, seed_input],
        outputs=result_image,
    )
    # No fn is given, so clicking an example only fills in the inputs.
    gr.Examples(
        examples=examples,
        inputs=[image_input, uploaded_mask, seed_input],
        outputs=result_image,
    )
demo.launch()