dream2589632147 committed on
Commit 97608ef · verified · 1 Parent(s): aa6d576

Update app.py

Files changed (1)
  1. app.py +201 -3
app.py CHANGED
@@ -1,10 +1,208 @@
  # =========================================================
  # GRADIO UI
  # =========================================================
  with gr.Blocks(theme=gr.themes.Soft()) as demo:

- # --- We removed the banner from here ---
-
  # --- PROFESSIONAL YOUTUBE EMBED SECTION ---
  gr.HTML("""
  <div style="background: linear-gradient(135deg, #b90000 0%, #ff0000 100%); color: white; padding: 25px; border-radius: 16px; text-align: center; margin-bottom: 25px; box-shadow: 0 10px 30px rgba(185, 0, 0, 0.3);">
@@ -91,7 +289,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
  cache_examples=False
  )

- # --- BOTTOM ADVERTISEMENT BANNER (moved here) ---
  gr.HTML("""
  <div style="background: linear-gradient(90deg, #4f46e5, #9333ea); color: white; padding: 15px; border-radius: 10px; text-align: center; margin-top: 20px; box-shadow: 0 4px 15px rgba(0,0,0,0.1);">
  <div style="display: flex; align-items: center; justify-content: center; gap: 20px; flex-wrap: wrap;">

+ import os
+ import spaces
+ import torch
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
+ from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
+ from diffusers.utils.export_utils import export_to_video
+ import gradio as gr
+ import tempfile
+ import numpy as np
+ from PIL import Image
+ import random
+ import gc
+ from torchao.quantization import quantize_
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
+ import aoti
+
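+ # NOTE: `aoti` is assumed to be a Space-local helper module (used via
+ # aoti.aoti_blocks_load below) for loading ahead-of-time compiled
+ # transformer blocks, not a published package.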
+ # =========================================================
+ # MODEL CONFIGURATION
+ # =========================================================
+ MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ MAX_DIM = 832
+ MIN_DIM = 480
+ SQUARE_DIM = 640
+ MULTIPLE_OF = 16
+ MAX_SEED = np.iinfo(np.int32).max
+
+ FIXED_FPS = 16
+ MIN_FRAMES_MODEL = 8
+ MAX_FRAMES_MODEL = 7720
+ MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
+ MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
+
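+ # At FIXED_FPS = 16 these bounds work out to MIN_DURATION = 0.5 s and
+ # MAX_DURATION = 482.5 s of generated video.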
+ # =========================================================
+ # LOAD PIPELINE
+ # =========================================================
+ print("Loading pipeline...")
+ pipe = WanImageToVideoPipeline.from_pretrained(
+     MODEL_ID,
+     transformer=WanTransformer3DModel.from_pretrained(
+         MODEL_ID,
+         subfolder="transformer",
+         torch_dtype=torch.bfloat16,
+         device_map="cuda",
+         token=HF_TOKEN
+     ),
+     transformer_2=WanTransformer3DModel.from_pretrained(
+         MODEL_ID,
+         subfolder="transformer_2",
+         torch_dtype=torch.bfloat16,
+         device_map="cuda",
+         token=HF_TOKEN
+     ),
+     torch_dtype=torch.bfloat16,
+ ).to("cuda")
+
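+ # Wan 2.2 A14B ships two expert transformers (assumed split: `transformer`
+ # for the high-noise denoising steps, `transformer_2` for the low-noise
+ # steps), so both are loaded explicitly and moved to the GPU.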
+ # =========================================================
+ # LOAD LORA ADAPTERS
+ # =========================================================
+ print("Loading LoRA adapters...")
+ pipe.load_lora_weights(
+     "Kijai/WanVideo_comfy",
+     weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+     adapter_name="lightx2v"
+ )
+ pipe.load_lora_weights(
+     "Kijai/WanVideo_comfy",
+     weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+     adapter_name="lightx2v_2",
+     load_into_transformer_2=True
+ )
+
+ pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
+ pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
+ pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
+ pipe.unload_lora_weights()
+
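+ # fuse_lora() bakes the scaled LoRA deltas into the base weights, so the
+ # adapter bookkeeping can then be dropped with unload_lora_weights(); fusing
+ # happens before quantization, while the weights are still plain bf16.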
+ # =========================================================
+ # QUANTIZATION & AOT OPTIMIZATION
+ # =========================================================
+ print("Applying quantization...")
+ quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
+ quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
+ quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
+
+ print("Loading AOTI blocks...")
+ aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
+ aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
+
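+ # The text encoder gets int8 weight-only quantization; both transformers get
+ # fp8 dynamic-activation / fp8-weight quantization, which matches the 'fp8da'
+ # variant of the precompiled blocks fetched from zerogpu-aoti/Wan2.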
+ # =========================================================
+ # DEFAULT PROMPTS
+ # =========================================================
+ default_prompt_i2v = "Make this image come alive with dynamic, cinematic human motion. Create smooth, natural, lifelike animation with fluid transitions, expressive body movement, realistic physics, and elegant camera flow. Deliver a polished, high-quality motion style that feels immersive, artistic, and visually captivating."
+
+ default_negative_prompt = (
+     "low quality, worst quality, motion artifacts, unstable motion, jitter, frame jitter, wobbling limbs, motion distortion, inconsistent movement, robotic movement, animation-like motion, awkward transitions, incorrect body mechanics, unnatural posing, off-balance poses, broken motion paths, frozen frames, duplicated frames, frame skipping, warped motion, stretching artifacts, bad anatomy, incorrect proportions, deformed body, twisted torso, broken joints, dislocated limbs, distorted neck, unnatural spine curvature, malformed hands, extra fingers, missing fingers, fused fingers, distorted legs, extra limbs, collapsed feet, floating feet, foot sliding, foot jitter, backward walking, unnatural gait, blurry details, long exposure blur, ghosting, shadow trails, smearing, washed-out colors, overexposure, underexposure, excessive contrast, blown highlights, poorly rendered clothing, fabric glitches, texture warping, clothing merging with body, incorrect cloth physics, ugly background, cluttered scene, crowded background, random objects, unwanted text, subtitles, logos, graffiti, grain, noise, static artifacts, compression noise, jpeg artifacts, image-like stillness, painting-like look, cartoon texture, low-resolution textures"
+ )
+
+ # =========================================================
+ # IMAGE RESIZING LOGIC
+ # =========================================================
+ def resize_image(image: Image.Image) -> Image.Image:
+     width, height = image.size
+     if width == height:
+         return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
+
+     aspect_ratio = width / height
+     MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
+     MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM
+
+     image_to_resize = image
+     if aspect_ratio > MAX_ASPECT_RATIO:
+         crop_width = int(round(height * MAX_ASPECT_RATIO))
+         left = (width - crop_width) // 2
+         image_to_resize = image.crop((left, 0, left + crop_width, height))
+     elif aspect_ratio < MIN_ASPECT_RATIO:
+         crop_height = int(round(width / MIN_ASPECT_RATIO))
+         top = (height - crop_height) // 2
+         image_to_resize = image.crop((0, top, width, top + crop_height))
+
+     if width > height:
+         target_w = MAX_DIM
+         target_h = int(round(target_w / aspect_ratio))
+     else:
+         target_h = MAX_DIM
+         target_w = int(round(target_h * aspect_ratio))
+
+     final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
+     final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
+
+     final_w = max(MIN_DIM, min(MAX_DIM, final_w))
+     final_h = max(MIN_DIM, min(MAX_DIM, final_h))
+
+     return image_to_resize.resize((final_w, final_h), Image.LANCZOS)
+
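+ # Example: a 1920x1080 input (aspect ~1.78, above MAX_DIM/MIN_DIM ~1.73) is
+ # center-cropped to 1872x1080, scaled toward 832 px wide, snapped to a
+ # multiple of 16, and clamped to a final 832x480.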
+ # =========================================================
+ # UTILITY FUNCTIONS
+ # =========================================================
+ def get_num_frames(duration_seconds: float):
+     return 1 + int(np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
+
+ def get_duration(
+     input_image, prompt, steps, negative_prompt,
+     duration_seconds, guidance_scale, guidance_scale_2,
+     seed, randomize_seed, progress,
+ ):
+     if input_image is None:
+         return 120
+     BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
+     BASE_STEP_DURATION = 15
+
+     width, height = resize_image(input_image).size
+     frames = get_num_frames(duration_seconds)
+     factor = frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
+     step_duration = BASE_STEP_DURATION * factor ** 1.5
+     return 10 + int(steps) * step_duration
+
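+ # get_duration mirrors generate_video's signature so @spaces.GPU(duration=...)
+ # can call it with the same arguments to budget the ZeroGPU slot: ~10 s of
+ # overhead plus a per-step cost that scales super-linearly with pixels * frames.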
+ # =========================================================
+ # MAIN GENERATION FUNCTION
+ # =========================================================
+ @spaces.GPU(duration=get_duration)
+ def generate_video(
+     input_image,
+     prompt,
+     steps=4,
+     negative_prompt=default_negative_prompt,
+     duration_seconds=MAX_DURATION,
+     guidance_scale=1,
+     guidance_scale_2=1,
+     seed=42,
+     randomize_seed=False,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     if input_image is None:
+         raise gr.Error("Please upload an input image.")
+
+     num_frames = get_num_frames(duration_seconds)
+     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+     resized_image = resize_image(input_image)
+
+     output_frames_list = pipe(
+         image=resized_image,
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         height=resized_image.height,
+         width=resized_image.width,
+         num_frames=num_frames,
+         guidance_scale=float(guidance_scale),
+         guidance_scale_2=float(guidance_scale_2),
+         num_inference_steps=int(steps),
+         generator=torch.Generator(device="cuda").manual_seed(current_seed),
+     ).frames[0]
+
+     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+         video_path = tmpfile.name
+
+     export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
+     return video_path, current_seed
+
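+ # NamedTemporaryFile(delete=False) only reserves a unique .mp4 path; the
+ # frames are written there by export_to_video at 16 fps, and the file
+ # persists so Gradio can serve it as the output video.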
  # =========================================================
  # GRADIO UI
  # =========================================================
  with gr.Blocks(theme=gr.themes.Soft()) as demo:

  # --- PROFESSIONAL YOUTUBE EMBED SECTION ---
  gr.HTML("""
  <div style="background: linear-gradient(135deg, #b90000 0%, #ff0000 100%); color: white; padding: 25px; border-radius: 16px; text-align: center; margin-bottom: 25px; box-shadow: 0 10px 30px rgba(185, 0, 0, 0.3);">
  cache_examples=False
  )

+ # --- BOTTOM ADVERTISEMENT BANNER ---
  gr.HTML("""
  <div style="background: linear-gradient(90deg, #4f46e5, #9333ea); color: white; padding: 15px; border-radius: 10px; text-align: center; margin-top: 20px; box-shadow: 0 4px 15px rgba(0,0,0,0.1);">
  <div style="display: flex; align-items: center; justify-content: center; gap: 20px; flex-wrap: wrap;">