# Spaces status: Runtime error
| import os | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| import gradio as gr | |
| from PIL import Image | |
| from datetime import datetime | |
| from morph_attn import DiffMorpherPipeline | |
| from lora_utils import train_lora | |
# Height and width passed to the image and video components of the UI.
LENGTH = 450
def train_lora_interface(
    image,
    prompt,
    model_path,
    output_path,
    lora_steps,
    lora_rank,
    lora_lr,
    num,
    progress=gr.Progress(),
):
    """Train a single LoRA and return a status message for the UI.

    Args:
        image: Image forwarded unchanged to ``train_lora``.
        prompt: Text prompt forwarded unchanged to ``train_lora``.
        model_path: Diffusion model identifier/path forwarded to ``train_lora``.
        output_path: Directory created if missing and forwarded to
            ``train_lora`` (presumably where the weights are written --
            confirm against ``lora_utils.train_lora``).
        lora_steps: Number of training steps.
        lora_rank: LoRA rank.
        lora_lr: LoRA learning rate.
        num: LoRA index; selects the weight file name ``lora_{num}.ckpt`` and
            the letter in the status message (``0`` -> "A", anything else -> "B").
        progress: Gradio progress tracker. It is declared as a default
            argument (rather than created inside the body) because that is how
            Gradio recognises and wires the tracker to the running event.

    Returns:
        The string ``"Train LoRA A Done!"`` or ``"Train LoRA B Done!"``.
    """
    os.makedirs(output_path, exist_ok=True)
    train_lora(
        image,
        prompt,
        output_path,
        model_path,
        lora_steps=lora_steps,
        lora_lr=lora_lr,
        lora_rank=lora_rank,
        weight_name=f"lora_{num}.ckpt",
        progress=progress,
    )
    return f"Train LoRA {'A' if num == 0 else 'B'} Done!"
def run_diffmorpher(
    image_0,
    image_1,
    prompt_0,
    prompt_1,
    model_path,
    lora_mode,
    lamb,
    use_adain,
    use_reschedule,
    num_frames,
    fps,
    load_lora_path_0,
    load_lora_path_1,
    output_path,
    progress=gr.Progress(),
):
    """Run the morphing pipeline and encode the resulting frames as an MP4.

    Args:
        image_0, image_1: The two endpoint images passed to the pipeline.
        prompt_0, prompt_1: Prompts passed to the pipeline for each image.
        model_path: Identifier/path given to
            ``DiffMorpherPipeline.from_pretrained``.
        lora_mode: ``"Fix LoRA A"`` -> ``fix_lora=0``, ``"Fix LoRA B"`` ->
            ``fix_lora=1``, any other value -> ``fix_lora=None``.
        lamb, use_adain, use_reschedule, num_frames: Forwarded unchanged to
            the pipeline.
        fps: Frame rate of the written video.
        load_lora_path_0, load_lora_path_1: LoRA checkpoint paths; when empty
            they default to ``{output_path}/lora_0.ckpt`` and
            ``{output_path}/lora_1.ckpt``.
        output_path: Directory (created if missing) that receives the video.
        progress: Gradio progress tracker, declared as a default argument so
            Gradio can attach it to the running event.

    Returns:
        A ``gr.Video`` pointing at ``{output_path}/{run_id}.mp4``.
    """
    # A single clock read, so the time and date parts cannot straddle a
    # minute/day boundary.
    run_id = datetime.now().strftime("%H%M_%Y%m%d")
    os.makedirs(output_path, exist_ok=True)

    # The pipeline is moved to "cuda" unconditionally: a GPU is required.
    morpher_pipeline = DiffMorpherPipeline.from_pretrained(
        model_path, torch_dtype=torch.float32
    ).to("cuda")

    # Unknown modes (e.g. "LoRA Interp") map to None.
    fix_lora = {"Fix LoRA A": 0, "Fix LoRA B": 1}.get(lora_mode)

    if not load_lora_path_0:
        load_lora_path_0 = f"{output_path}/lora_0.ckpt"
    if not load_lora_path_1:
        load_lora_path_1 = f"{output_path}/lora_1.ckpt"

    images = morpher_pipeline(
        img_0=image_0,
        img_1=image_1,
        prompt_0=prompt_0,
        prompt_1=prompt_1,
        load_lora_path_0=load_lora_path_0,
        load_lora_path_1=load_lora_path_1,
        lamb=lamb,
        use_adain=use_adain,
        use_reschedule=use_reschedule,
        num_frames=num_frames,
        fix_lora=fix_lora,
        progress=progress,
    )

    # OpenCV expects BGR channel order; the frames are converted from RGB.
    frames = [
        cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) for image in images
    ]
    # Size the writer from the actual frames: VideoWriter silently drops
    # frames whose size differs from the one it was opened with.
    height, width = frames[0].shape[:2] if frames else (512, 512)

    video_path = f"{output_path}/{run_id}.mp4"
    video = cv2.VideoWriter(
        video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )
    try:
        for frame in frames:
            video.write(frame)
    finally:
        # Always finalise the file, even if a write fails.
        video.release()
    # No cv2.destroyAllWindows() here: no window is ever opened, and the call
    # raises on headless OpenCV builds.

    return gr.Video(
        value=video_path,
        format="mp4",
        label="Output video",
        show_label=True,
        height=LENGTH,
        width=LENGTH,
        interactive=False,
    )
def run_all(
    image_0,
    image_1,
    prompt_0,
    prompt_1,
    model_path,
    lora_mode,
    lamb,
    use_adain,
    use_reschedule,
    num_frames,
    fps,
    load_lora_path_0,
    load_lora_path_1,
    output_path,
    lora_steps,
    lora_rank,
    lora_lr,
    progress=gr.Progress(),
):
    """Train both LoRAs, then run the morphing step.

    The LoRA for ``image_0`` is trained with weight name ``lora_0.ckpt`` and
    the one for ``image_1`` with ``lora_1.ckpt``; afterwards
    ``run_diffmorpher`` is called and its result returned.

    Args:
        image_0, image_1, prompt_0, prompt_1: Image/prompt pairs, used both
            for training and for morphing.
        model_path, output_path: Forwarded to ``train_lora`` and
            ``run_diffmorpher``; ``output_path`` is created if missing.
        lora_mode, lamb, use_adain, use_reschedule, num_frames, fps,
        load_lora_path_0, load_lora_path_1: Forwarded unchanged to
            ``run_diffmorpher``.
        lora_steps, lora_rank, lora_lr: LoRA training hyper-parameters.
        progress: Gradio progress tracker for the two training runs, declared
            as a default argument so Gradio can attach it to the event.

    Returns:
        Whatever ``run_diffmorpher`` returns.
    """
    os.makedirs(output_path, exist_ok=True)

    # Same training call for both images; only the inputs and the index in
    # the weight file name differ.
    for index, (image, prompt) in enumerate(
        ((image_0, prompt_0), (image_1, prompt_1))
    ):
        train_lora(
            image,
            prompt,
            output_path,
            model_path,
            lora_steps=lora_steps,
            lora_lr=lora_lr,
            lora_rank=lora_rank,
            weight_name=f"lora_{index}.ckpt",
            progress=progress,
        )

    return run_diffmorpher(
        image_0,
        image_1,
        prompt_0,
        prompt_1,
        model_path,
        lora_mode,
        lamb,
        use_adain,
        use_reschedule,
        num_frames,
        fps,
        load_lora_path_0,
        load_lora_path_1,
        output_path,
    )
# Gradio UI definition and event wiring.
# NOTE(review): the nesting of rows/columns/tabs below is reconstructed from a
# source whose indentation was lost -- confirm the layout against the running app.
with gr.Blocks() as demo:

    def prepare_image(pil_image):
        """Return ``pil_image`` as a 512x512 RGB image (bilinear resize)."""
        return pil_image.convert("RGB").resize((512, 512), Image.BILINEAR)

    with gr.Row():
        gr.Markdown("""
        # Official Implementation of [DiffMorpher](https://kevin-thu.github.io/DiffMorpher_page/)
        """)

    # Preprocessed copies of the two input images; these (not the raw image
    # widgets) are what the training / morphing handlers receive.
    original_image_0 = gr.State(prepare_image(Image.open("Musk.jpg")))
    original_image_1 = gr.State(prepare_image(Image.open("Feifei.jpg")))

    with gr.Row():
        with gr.Column():
            input_img_0 = gr.Image(
                type="numpy",
                label="Input image A",
                value="Musk.jpg",
                show_label=True,
                height=LENGTH,
                width=LENGTH,
                interactive=True,
            )
            prompt_0 = gr.Textbox(
                label="Prompt for image A",
                value="a photo of a man's face",
                interactive=True,
            )
            with gr.Row():
                train_lora_0_button = gr.Button("Train LoRA A")
                train_lora_1_button = gr.Button("Train LoRA B")
        with gr.Column():
            input_img_1 = gr.Image(
                type="numpy",
                label="Input image B ",
                value="Feifei.jpg",
                show_label=True,
                height=LENGTH,
                width=LENGTH,
                interactive=True,
            )
            prompt_1 = gr.Textbox(
                label="Prompt for image B",
                value="a photo of a woman's face",
                interactive=True,
            )
            with gr.Row():
                clear_button = gr.Button("Clear All")
                run_button = gr.Button("Run w/o LoRA training")
        with gr.Column():
            output_video = gr.Video(
                format="mp4",
                label="Output video",
                show_label=True,
                height=LENGTH,
                width=LENGTH,
                interactive=False,
            )
            lora_progress_bar = gr.Textbox(
                label="Display LoRA training progress", interactive=False
            )
            run_all_button = gr.Button("Run!")

    with gr.Row():
        gr.Markdown("""
        ### Usage:
        1. Upload two images (with correspondence) and fill out the prompts.
        (It's recommended to change `[Output path]` accordingly.)
        2. Click **"Run!"**
        Or:
        1. Upload two images (with correspondence) and fill out the prompts.
        2. Click the **"Train LoRA A/B"** button to fit two LoRAs for two images respectively. <br>
        If you have trained LoRA A or LoRA B before, you can skip the step and fill the specific LoRA path in LoRA settings. <br>
        Trained LoRAs are saved to `[Output Path]/lora_0.ckpt` and `[Output Path]/lora_1.ckpt` by default.
        3. You might also change the settings below.
        4. Click **"Run w/o LoRA training"**
        ### Note:
        1. **Try restarting the space if you got an error.** (This is because the storage is limited now.)
        2. Besides morphing, you can also try animations to make smooth videos too.
        3. To speed up the generation process, you can **reduce the number of frames** or **turn off "Use Reschedule"**.
        4. You can try the influence of different prompts. It seems that using the same prompts or aligned prompts works better.
        ### Have fun!
        """)

    with gr.Accordion(label="Algorithm Parameters"):
        with gr.Tab("Basic Settings"):
            with gr.Row():
                model_path = gr.Text(
                    value="stabilityai/stable-diffusion-2-1-base",
                    label="Diffusion Model Path",
                    interactive=True,
                )
                lamb = gr.Slider(
                    value=0.6,
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    label="Lambda for attention replacement",
                    interactive=True,
                )
                lora_mode = gr.Dropdown(
                    value="LoRA Interp",
                    label="LoRA Interp. or Fix LoRA",
                    choices=["LoRA Interp", "Fix LoRA A", "Fix LoRA B"],
                    interactive=True,
                )
                use_adain = gr.Checkbox(
                    value=True, label="Use AdaIN", interactive=True
                )
                use_reschedule = gr.Checkbox(
                    value=True, label="Use Reschedule", interactive=True
                )
            with gr.Row():
                num_frames = gr.Number(
                    value=16,
                    minimum=0,
                    label="Number of Frames",
                    precision=0,
                    interactive=True,
                )
                # A frame rate of 0 cannot produce a playable video, so the
                # lower bound is 1.
                fps = gr.Number(
                    value=8,
                    minimum=1,
                    label="FPS (Frame rate)",
                    precision=0,
                    interactive=True,
                )
                output_path = gr.Text(
                    value="./results", label="Output Path", interactive=True
                )
        with gr.Tab("LoRA Settings"):
            with gr.Row():
                lora_steps = gr.Number(
                    value=200,
                    label="LoRA training steps",
                    precision=0,
                    interactive=True,
                )
                lora_lr = gr.Number(
                    value=0.0002, label="LoRA learning rate", interactive=True
                )
                lora_rank = gr.Number(
                    value=16, label="LoRA rank", precision=0, interactive=True
                )
                load_lora_path_0 = gr.Text(
                    value="",
                    label="LoRA model load path for image A",
                    interactive=True,
                )
                load_lora_path_1 = gr.Text(
                    value="",
                    label="LoRA model load path for image B",
                    interactive=True,
                )

    def store_img(img):
        """Turn an uploaded numpy image into the preprocessed PIL image kept in State."""
        return prepare_image(Image.fromarray(img))

    input_img_0.upload(store_img, [input_img_0], [original_image_0])
    input_img_1.upload(store_img, [input_img_1], [original_image_1])

    def clear():
        """Reset both image widgets, both stored images and both prompts."""
        # gr.update works across Gradio versions, unlike the per-component
        # ``gr.Image.update`` classmethod.
        blank_image = gr.update(value=None, width=LENGTH, height=LENGTH)
        return blank_image, blank_image, None, None, None, None

    clear_button.click(
        clear,
        None,
        [
            input_img_0,
            input_img_1,
            original_image_0,
            original_image_1,
            prompt_0,
            prompt_1,
        ],
    )

    # Settings shared by both "Train LoRA" buttons, in the order expected by
    # train_lora_interface after (image, prompt).
    lora_train_settings = [
        model_path,
        output_path,
        lora_steps,
        lora_rank,
        lora_lr,
    ]

    # Hidden Number components feed the constant LoRA index (0 = A, 1 = B)
    # to the shared handler.
    train_lora_0_button.click(
        train_lora_interface,
        [
            original_image_0,
            prompt_0,
            *lora_train_settings,
            gr.Number(value=0, visible=False, precision=0),
        ],
        [lora_progress_bar],
    )
    train_lora_1_button.click(
        train_lora_interface,
        [
            original_image_1,
            prompt_1,
            *lora_train_settings,
            gr.Number(value=1, visible=False, precision=0),
        ],
        [lora_progress_bar],
    )

    # Inputs of run_diffmorpher, in signature order; run_all takes the same
    # list followed by the LoRA training hyper-parameters.
    morph_inputs = [
        original_image_0,
        original_image_1,
        prompt_0,
        prompt_1,
        model_path,
        lora_mode,
        lamb,
        use_adain,
        use_reschedule,
        num_frames,
        fps,
        load_lora_path_0,
        load_lora_path_1,
        output_path,
    ]

    run_button.click(run_diffmorpher, morph_inputs, [output_video])
    run_all_button.click(
        run_all,
        [*morph_inputs, lora_steps, lora_rank, lora_lr],
        [output_video],
    )
# Enable request queuing, then start the server in debug mode.
demo.queue().launch(debug=True)