Antoni Bigata committed · 2dff4e4
Parent(s): cf0da47
requirements
app.py CHANGED
@@ -186,54 +186,17 @@ DEFAULT_AUDIO_PATH = os.path.join(
 # landmarks_extractor,
 # ) = load_all_models()
 
-
-vae_model = VaeWrapper("video")
-
-vae_model = vae_model.half()  # Convert to half precision
-try:
-    vae_model = torch.compile(vae_model)
-    print("Successfully compiled vae_model in FP16")
-except Exception as e:
-    print(f"Warning: Failed to compile vae_model: {e}")
-
-hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
-hubert_model = hubert_model.half()  # Convert to half precision
-try:
-    hubert_model = torch.compile(hubert_model)
-    print("Successfully compiled hubert_model in FP16")
-except Exception as e:
-    print(f"Warning: Failed to compile hubert_model: {e}")
-
-wavlm_model = WavLM_wrapper(
-    model_size="Base+",
-    feed_as_frames=False,
-    merge_type="None",
-    model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
-).cuda()
-
-wavlm_model = wavlm_model.half()  # Convert to half precision
-try:
-    wavlm_model = torch.compile(wavlm_model)
-    print("Successfully compiled wavlm_model in FP16")
-except Exception as e:
-    print(f"Warning: Failed to compile wavlm_model: {e}")
-
-landmarks_extractor = LandmarksExtractor()
-keyframe_model = load_model(
-    config="keyframe.yaml",
-    ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
-)
-interpolation_model = load_model(
-    config="interpolation.yaml",
-    ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
-)
-keyframe_model.en_and_decode_n_samples_a_time = 2
-interpolation_model.en_and_decode_n_samples_a_time = 2
+keyframe_model = None
+interpolation_model = None
+vae_model = None
+hubert_model = None
+wavlm_model = None
+landmarks_extractor = None
 
 
 @spaces.GPU(duration=60)
 @torch.no_grad()
-def compute_video_embedding(video_reader, min_len):
+def compute_video_embedding(video_reader, min_len, vae_model):
     """Compute embeddings from video"""
 
     total_frames = min_len
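Aside: the hunk above replaces eager, import-time model construction with `None` placeholders; construction itself moves into `process_video` (last hunks below). A minimal sketch of the same lazy-init pattern, using a hypothetical `get_vae` helper (the `VaeWrapper("video")`, FP16, and `torch.compile` steps come from this diff; the helper name and the assumption that `VaeWrapper` is importable as in app.py do not):

import torch

vae_model = None

def get_vae():
    # Build the VAE once, on first use, and cache it in the module global.
    global vae_model
    if vae_model is None:
        vae_model = VaeWrapper("video").half()  # FP16, as in the commit
        try:
            # torch.compile is an optional speedup; the eager model is kept on failure
            vae_model = torch.compile(vae_model)
        except Exception as e:
            print(f"Warning: Failed to compile vae_model: {e}")
    return vae_model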
@@ -283,7 +246,7 @@ def compute_video_embedding(video_reader, min_len):
 
 @spaces.GPU(duration=120)
 @torch.no_grad()
-def compute_hubert_embedding(raw_audio):
+def compute_hubert_embedding(raw_audio, hubert_model):
     """Compute embeddings from audio"""
     print(f"Computing audio embedding from {raw_audio.shape}")
 
@@ -330,7 +293,7 @@ def compute_hubert_embedding(raw_audio):
 
 @spaces.GPU(duration=120)
 @torch.no_grad()
-def compute_wavlm_embedding(raw_audio):
+def compute_wavlm_embedding(raw_audio, wavlm_model):
     """Compute embeddings from audio"""
     audio = rearrange(raw_audio, "(f s) -> f s", s=640)
 
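In `compute_wavlm_embedding`, `rearrange(raw_audio, "(f s) -> f s", s=640)` splits the flat waveform into fixed 640-sample frames; at the 16 kHz rate these speech encoders typically consume (an assumption, the sample rate is not visible in this diff), 640 samples is 40 ms, i.e. 25 frames per second. A self-contained example of the einops pattern:

import torch
from einops import rearrange

raw_audio = torch.randn(25 * 640)  # one second of dummy audio, flattened
frames = rearrange(raw_audio, "(f s) -> f s", s=640)  # requires len % 640 == 0
print(frames.shape)  # torch.Size([25, 640])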
@@ -369,7 +332,7 @@ def compute_wavlm_embedding(raw_audio):
 
 
 @torch.no_grad()
-def extract_video_landmarks(video_frames):
+def extract_video_landmarks(video_frames, landmarks_extractor):
     """Extract landmarks from video frames"""
 
     # Create a progress bar for Gradio
@@ -666,6 +629,57 @@ def process_video(video_input, audio_input, max_num_seconds):
         duration=10,
     )
 
+    if vae_model is None:
+        vae_model = VaeWrapper("video")
+        vae_model = vae_model.half()  # Convert to half precision
+        try:
+            vae_model = torch.compile(vae_model)
+            print("Successfully compiled vae_model in FP16")
+        except Exception as e:
+            print(f"Warning: Failed to compile vae_model: {e}")
+
+    if hubert_model is None:
+        hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
+        hubert_model = hubert_model.half()  # Convert to half precision
+        try:
+            hubert_model = torch.compile(hubert_model)
+            print("Successfully compiled hubert_model in FP16")
+        except Exception as e:
+            print(f"Warning: Failed to compile hubert_model: {e}")
+
+    if wavlm_model is None:
+        wavlm_model = WavLM_wrapper(
+            model_size="Base+",
+            feed_as_frames=False,
+            merge_type="None",
+            model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
+        ).cuda()
+
+        wavlm_model = wavlm_model.half()  # Convert to half precision
+        try:
+            wavlm_model = torch.compile(wavlm_model)
+            print("Successfully compiled wavlm_model in FP16")
+        except Exception as e:
+            print(f"Warning: Failed to compile wavlm_model: {e}")
+
+    if landmarks_extractor is None:
+        landmarks_extractor = LandmarksExtractor()
+
+    if keyframe_model is None:
+        keyframe_model = load_model(
+            config="keyframe.yaml",
+            ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
+        )
+
+    if interpolation_model is None:
+        interpolation_model = load_model(
+            config="interpolation.yaml",
+            ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
+        )
+
+    keyframe_model.en_and_decode_n_samples_a_time = 2
+    interpolation_model.en_and_decode_n_samples_a_time = 2
+
     # Use default media if none provided
     if video_input is None:
         video_input = DEFAULT_VIDEO_PATH
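The block added above constructs every model on the first call, inside the `@spaces.GPU`-decorated `process_video`; on ZeroGPU Spaces the GPU is attached only while such a function runs, which is presumably why `.cuda()` construction moved out of module scope (this rationale is a reading of the change, not stated in the commit message). A stripped-down sketch of the pattern with a hypothetical `heavy_model`:

import spaces
import torch

heavy_model = None  # constructed lazily, inside the GPU context

@spaces.GPU(duration=60)
@torch.no_grad()
def predict(x):
    global heavy_model
    if heavy_model is None:  # first call pays the one-time load cost
        heavy_model = torch.nn.Linear(16, 16).cuda().half()
    return heavy_model(x.cuda().half())

The commit itself does not use `global`; it threads the freshly built models into the embedding helpers as explicit arguments, as the next two hunks show.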
@@ -749,9 +763,9 @@ def process_video(video_input, audio_input, max_num_seconds):
 
         # Compute video embeddings and landmarks - store full version in cache
         video_embedding, video_frames = compute_video_embedding(
-            video_reader, len(video_reader)
+            video_reader, len(video_reader), vae_model
         )
-        video_landmarks = extract_video_landmarks(video_frames)
+        video_landmarks = extract_video_landmarks(video_frames, landmarks_extractor)
 
         # Update video cache with full versions
         cache["video"]["path"] = video_path_hash
@@ -807,8 +821,8 @@ def process_video(video_input, audio_input, max_num_seconds):
         print("Computing audio embeddings")
 
         # Compute audio embeddings with the truncated audio
-        hubert_embedding = compute_hubert_embedding(raw_audio_reshape)
-        wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape)
+        hubert_embedding = compute_hubert_embedding(raw_audio_reshape, hubert_model)
+        wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape, wavlm_model)
 
         # Update audio cache with full embeddings
         # Note: raw_audio was already cached above