Spaces:
Runtime error
Runtime error
| import torch | |
| from torchvision import transforms | |
| import numpy as np | |
| from skimage.color import rgb2lab, lab2rgb | |
| import skimage.transform | |
| from PIL import Image | |
| import os | |
| from tqdm import tqdm | |
| from moviepy.editor import VideoFileClip, AudioFileClip | |
| from moviepy.tools import cvsecs | |
| import cv2 | |
| from pdb import set_trace | |
def lab_to_rgb(L, ab):
    """Convert a batch of normalized Lab tensors into a stacked array of RGB images.

    Args:
        L: Lightness channel, expected in [-1, 1]; it is mapped back to
            [0, 100] via ``(L + 1) * 50``.
        ab: Chroma channels, expected divided by 110; multiplied back here.
            ``L`` and ``ab`` are concatenated along dim 1.

    Returns:
        numpy array stacking the ``lab2rgb`` result of every image along axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    lab_batch = torch.cat([lightness, chroma], dim=1)
    # Move dim 1 (the channels) last: (N, C, H, W) -> (N, H, W, C), as skimage expects.
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)
# Side length, in pixels, of the square image fed to the model: get_L resizes
# every input to SIZE x SIZE before extracting the L channel.
SIZE = 256
def get_L(img):
    """Return the normalized L (lightness) channel of ``img`` as a model-input tensor.

    The image is resized to ``SIZE`` x ``SIZE`` with bicubic interpolation and
    converted to Lab. Its L channel is then rescaled from [0, 100] to [-1, 1].

    Args:
        img: PIL image. Non-RGB modes (e.g. grayscale ``"L"``, ``"RGBA"``,
            palette ``"P"``) are converted to RGB first, because ``rgb2lab``
            needs exactly three colour channels.

    Returns:
        float32 tensor of shape (1, SIZE, SIZE).
    """
    # A colorizer is naturally handed grayscale images; without this,
    # np.array() yields a 2-D (or 4-channel) array and rgb2lab raises.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")
    img = transforms.Resize(
        (SIZE, SIZE), transforms.InterpolationMode.BICUBIC)(img)
    img_lab = rgb2lab(np.array(img)).astype("float32")
    # ToTensor turns the HWC float array into a CHW tensor (no /255 scaling for floats).
    img_lab = transforms.ToTensor()(img_lab)
    # Index with [0] (a list) to keep a leading channel dim; L in [0, 100] -> [-1, 1].
    L = img_lab[[0], ...] / 50. - 1.
    return L
def get_predictions(model, L):
    """Colorize a batch of L channels with ``model`` and return RGB images.

    The model is switched to eval mode and run without gradient tracking on
    the CPU. It is driven through its attributes: the input is assigned to
    ``model.L`` and ``model.forward()`` is called with no arguments. The
    predicted ab channels are then read back from ``model.fake_color``.

    Returns:
        numpy array of RGB images, as produced by ``lab_to_rgb``.
    """
    cpu = torch.device('cpu')
    model.eval()
    with torch.no_grad():
        model.L = L.to(cpu)
        model.forward()
        predicted_ab = model.fake_color.detach()
        return lab_to_rgb(L, predicted_ab)
def colorize_img(model, img):
    """Colorize one PIL image with ``model`` and return it at the image's own size.

    The model works at ``SIZE`` x ``SIZE``. Its single output image is resized
    back to ``(height, width)`` with ``skimage.transform.resize``.
    """
    batch = get_L(img)[None]  # add a leading batch dimension of size 1
    colorized = get_predictions(model, batch)[0]  # the batch's only image
    width, height = img.size  # PIL reports (width, height)
    return skimage.transform.resize(colorized, (height, width))
def valid_start_end(duration, start_input, end_input):
    """Validate user-supplied clip bounds and convert them to seconds.

    Args:
        duration: Length of the video in seconds.
        start_input: Clip start in any form ``moviepy.tools.cvsecs`` accepts
            (e.g. ``"00:01:30"``, ``"90"``, ``90``). ``''``, whitespace-only
            or ``None`` means the start of the video.
        end_input: Clip end, same forms; blank/``None`` means ``duration``.

    Returns:
        ``(start, end)`` as floats in seconds, clamped to ``[0, duration]``.

    Raises:
        Exception: if a bound cannot be interpreted as a time, or if the
            clamped start is not strictly before the clamped end.
    """
    def _is_blank(value):
        # None would otherwise pass through cvsecs unchanged and crash in max().
        return value is None or (isinstance(value, str) and not value.strip())

    start = 0 if _is_blank(start_input) else start_input
    end = duration if _is_blank(end_input) else end_input
    try:
        # float() ensures the bounds are comparable numbers; anything that
        # cvsecs passes through unconverted is reported as invalid here.
        start = float(cvsecs(start))
        end = float(cvsecs(end))
    except Exception as err:
        # start, end aren't actual time values. Catch Exception, not
        # BaseException, so Ctrl-C / SystemExit still propagate.
        raise Exception("Invalid start, end values") from err
    # Clamp into the video: no negative start, no end past the video's duration.
    start = max(start, 0)
    end = min(duration, end)
    # start must be less than end
    if start >= end:
        raise Exception("Start must be before end.")
    return start, end
def _frame_to_bgr_uint8(frame):
    """Convert a colorized RGB frame into the uint8 BGR layout cv2.VideoWriter expects.

    Frames whose maximum is <= 1 are treated as floats in [0, 1] and scaled
    to [0, 255]; anything else is assumed to be on the 0-255 scale already.
    Values are clipped before the uint8 cast so small negative or >255
    excursions cannot wrap around.
    """
    if frame.max() <= 1:
        frame = frame * 255
    frame = np.clip(frame, 0, 255).astype(np.uint8)
    # The model output is RGB and OpenCV writers expect BGR. (RGB2BGR and
    # BGR2RGB perform the same channel swap; this name states the intent.)
    return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)


def _remove_if_exists(path):
    """Delete ``path``, ignoring the case where it was never created."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _write_colorized_frames(clip, model, fps, nframes, path_video):
    """Colorize every frame of ``clip`` sampled at ``fps``; write a silent video to ``path_video``."""
    out = cv2.VideoWriter(
        path_video,
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        tuple(clip.size))
    if not out.isOpened():
        # Otherwise cv2 silently writes nothing and a later step fails confusingly.
        raise Exception(f"Could not open video writer for {path_video}.")
    try:
        for frame in tqdm(clip.iter_frames(fps=fps), total=nframes):
            color_frame = colorize_img(model, Image.fromarray(frame))
            out.write(_frame_to_bgr_uint8(color_frame))
    finally:
        out.release()


def _mux_audio(clip, path_video, path_audio, path_output):
    """Write ``path_output``: the silent video at ``path_video`` plus ``clip``'s audio, if any."""
    silent_video = VideoFileClip(path_video)
    audio = None
    try:
        output_video = silent_video
        # Clips without an audio track have ``audio`` set to None.
        if clip.audio is not None:
            # The subclip's audio is not carried into the cv2-written file,
            # so round-trip it through a temporary audio file.
            clip.audio.write_audiofile(path_audio, logger=None)
            audio = AudioFileClip(path_audio)
            output_video = silent_video.set_audio(audio)
        output_video.write_videofile(path_output, logger=None)
    finally:
        # Close readers before deleting their files (required on Windows).
        silent_video.close()
        if audio is not None:
            audio.close()
        if clip.audio is not None:
            _remove_if_exists(path_audio)


def colorize_vid(path_input, model, fps, start_input, end_input):
    """Colorize a (sub)clip of a video frame by frame and save it next to the input.

    Args:
        path_input: Path of the source video.
        model: Colorization model, as accepted by ``colorize_img``.
        fps: Output frame rate. A positive int or float is used as given;
            any other value (e.g. ``None``) falls back to the clip's own fps.
        start_input: Clip start, in any form ``valid_start_end`` accepts.
        end_input: Clip end, in any form ``valid_start_end`` accepts.

    Returns:
        Path of the result: the input path with ``_out`` inserted before the
        extension. The original audio is kept when the clip has any.

    Raises:
        Exception: for invalid start/end values, or if the temporary video
            writer cannot be opened.
    """
    original_video = VideoFileClip(path_input)
    try:
        # validate start, end
        start, end = valid_start_end(
            original_video.duration, start_input, end_input)
        input_video = original_video.subclip(start, end)

        # bool is excluded explicitly because it is a subclass of int.
        if (isinstance(fps, (int, float)) and not isinstance(fps, bool)
                and fps > 0):
            used_fps = fps
        else:
            used_fps = input_video.fps
        # Count frames from the subclip's duration; the reader's nframes
        # describes the whole source file, not the selected range.
        nframes = int(round(used_fps * input_video.duration))
        print(
            f"Colorizing output with FPS: {used_fps}, nframes: {nframes}, resolution: {input_video.size}.")

        # Temp and output paths share the input's base name and extension.
        base_path, suffix = os.path.splitext(path_input)
        path_video_tmp = base_path + "_tmp" + suffix
        path_audio_tmp = base_path + "_audio_tmp.mp3"
        path_output = base_path + "_out" + suffix
        try:
            _write_colorized_frames(
                input_video, model, used_fps, nframes, path_video_tmp)
            _mux_audio(input_video, path_video_tmp, path_audio_tmp, path_output)
        finally:
            _remove_if_exists(path_video_tmp)
    finally:
        original_video.close()
    print("Done.")
    return path_output