import os
import shutil

import cv2
import gradio as gr
import imageio
import numpy as np
import spaces
import torch
import transformers
from gradio_imageslider import ImageSlider
from huggingface_hub import login
from PIL import Image, ImageFilter

from infer import lotus, lotus_video

transformers.utils.move_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if "HF_TOKEN_LOGIN" in os.environ:
login(token=os.environ["HF_TOKEN_LOGIN"])
def apply_gaussian_blur(image, radius=1.0):
    """Apply a Gaussian blur with the given radius to a PIL image."""
    return image.filter(ImageFilter.GaussianBlur(radius=radius))
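# Usage sketch (hypothetical path): soften an image before further processing.
#   img = Image.open("files/images/example.jpg")
#   softened = apply_gaussian_blur(img, radius=0.75)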
# NormalMapSimple and ConvertNormals follow the ComfyUI custom-node interface
# (INPUT_TYPES / RETURN_TYPES / FUNCTION / CATEGORY); that node metadata is
# unused here, where the classes are instantiated and called directly.
class NormalMapSimple:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "scale_XY": ("FLOAT", {"default": 1, "min": 0, "max": 100, "step": 0.001}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "normal_map"
    CATEGORY = "image/filters"
    def normal_map(self, images, scale_XY):
        # images: (B, H, W, C) float tensor in [0, 1]; the mean of the RGB
        # channels is treated as a height field.
        t = images.detach().clone().cpu().numpy().astype(np.float32)
        L = np.mean(t[:, :, :, :3], axis=3)
        for i in range(t.shape[0]):
            # Scharr derivatives give the X/Y slope of the height field.
            # borderType must be passed by keyword: the fifth positional
            # argument of cv2.Scharr is dst, not borderType.
            t[i, :, :, 0] = cv2.Scharr(L[i], -1, 1, 0, borderType=cv2.BORDER_REFLECT) * -1
            t[i, :, :, 1] = cv2.Scharr(L[i], -1, 0, 1, borderType=cv2.BORDER_REFLECT)
        t[:, :, :, 2] = 1
        t = torch.from_numpy(t)
        t[:, :, :, :2] *= scale_XY
        # Normalize to unit vectors, then map from [-1, 1] into [0, 1]
        t[:, :, :, :3] = torch.nn.functional.normalize(t[:, :, :, :3], dim=3) / 2 + 0.5
        return (t,)
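# Usage sketch: height-to-normal on one RGB frame (tensor layout (B, H, W, C)
# in [0, 1], matching how process_normal_map() calls this below):
#   gen = NormalMapSimple()
#   rgb = torch.rand(1, 256, 256, 3)
#   normals = gen.normal_map(rgb, scale_XY=1.0)[0]  # (1, 256, 256, 3) in [0, 1]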
class ConvertNormals:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "normals": ("IMAGE",),
                "input_mode": (["BAE", "MiDaS", "Standard", "DirectX"],),
                "output_mode": (["BAE", "MiDaS", "Standard", "DirectX"],),
                "scale_XY": ("FLOAT", {"default": 1, "min": 0, "max": 100, "step": 0.001}),
                "normalize": ("BOOLEAN", {"default": True}),
                "fix_black": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "optional_fill": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "convert_normals"
    CATEGORY = "image/filters"
    def convert_normals(self, normals, input_mode, output_mode, scale_XY, normalize, fix_black, optional_fill=None):
        """Re-encode a normal map between the BAE / MiDaS / Standard (OpenGL) / DirectX conventions."""
        try:
            t = normals.detach().clone()
if input_mode == "BAE":
t[:,:,:,0] = 1 - t[:,:,:,0] # invert R
elif input_mode == "MiDaS":
t[:,:,:,:3] = torch.stack([1 - t[:,:,:,2], t[:,:,:,1], t[:,:,:,0]], dim=3) # BGR -> RGB and invert R
elif input_mode == "DirectX":
t[:,:,:,1] = 1 - t[:,:,:,1] # invert G
            if fix_black:
                # A Z component below 0.5 cannot come from a valid unit normal;
                # key is 1 where a pixel is fully black/invalid, 0 where it is fine.
                key = torch.clamp(1 - t[:,:,:,2] * 2, min=0, max=1)
                if optional_fill is None:
                    # Patch invalid pixels with the neutral up-facing normal (0.5, 0.5, 1)
                    t[:,:,:,0] += key * 0.5
                    t[:,:,:,1] += key * 0.5
                    t[:,:,:,2] += key
else:
fill = optional_fill.detach().clone()
if fill.shape[1:3] != t.shape[1:3]:
fill = torch.nn.functional.interpolate(fill.movedim(-1,1), size=(t.shape[1], t.shape[2]), mode='bilinear').movedim(1,-1)
if fill.shape[0] != t.shape[0]:
fill = fill[0].unsqueeze(0).expand(t.shape[0], -1, -1, -1)
t[:,:,:,:3] += fill[:,:,:,:3] * key.unsqueeze(3).expand(-1, -1, -1, 3)
            # Scale the X/Y components around the 0.5 midpoint
            t[:,:,:,:2] = (t[:,:,:,:2] - 0.5) * scale_XY + 0.5
if normalize:
# Transform to [-1, 1] range
t_norm = t[:,:,:,:3] * 2 - 1
# Calculate the length of each vector
lengths = torch.sqrt(torch.sum(t_norm**2, dim=3, keepdim=True))
# Avoid division by zero
lengths = torch.clamp(lengths, min=1e-6)
# Normalize each vector to unit length
t_norm = t_norm / lengths
# Transform back to [0, 1] range
t[:,:,:,:3] = (t_norm + 1) / 2
if output_mode == "BAE":
t[:,:,:,0] = 1 - t[:,:,:,0] # invert R
elif output_mode == "MiDaS":
t[:,:,:,:3] = torch.stack([t[:,:,:,2], t[:,:,:,1], 1 - t[:,:,:,0]], dim=3) # invert R and BGR -> RGB
elif output_mode == "DirectX":
t[:,:,:,1] = 1 - t[:,:,:,1] # invert G
return (t,)
        except Exception as e:
            # Fall back to the unmodified input rather than crashing the demo
            print(f"Error in convert_normals: {e}")
            return (normals,)
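# Usage sketch (hypothetical `directx_normals` tensor, (B, H, W, 3) in [0, 1]):
# re-encode a DirectX-style map (Y down) as Standard/OpenGL (Y up):
#   conv = ConvertNormals()
#   opengl = conv.convert_normals(directx_normals, input_mode="DirectX",
#                                 output_mode="Standard", scale_XY=1.0,
#                                 normalize=True, fix_black=False)[0]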
def get_image_intensity(img, gamma_correction=1.0):
    """
    Extract an intensity map from an RGB image via the V channel of HSV.
    gamma_correction > 1 darkens the map, < 1 brightens it.
    """
# Convert to HSV color space
result = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
# Extract Value channel (intensity)
result = result[:, :, 2].astype(np.float32) / 255.0
# Apply gamma correction
result = result ** gamma_correction
# Convert back to 0-255 range
result = (result * 255.0).clip(0, 255).astype(np.uint8)
# Convert to RGB (still grayscale but in RGB format)
result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
return result
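# Usage sketch (hypothetical uint8 RGB array): brightness map of a frame,
# slightly darkened with gamma > 1:
#   intensity = get_image_intensity(rgb_uint8_array, gamma_correction=1.2)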
def blend_numpy_images(image1, image2, blend_factor=0.25, mode="normal"):
    """
    Alpha-blend two same-sized uint8 images:
    result = image1 * (1 - blend_factor) + image2 * blend_factor.
    Only "normal" blending is implemented; the `mode` argument is ignored.
    """
# Convert to float32 and normalize to 0-1
img1 = image1.astype(np.float32) / 255.0
img2 = image2.astype(np.float32) / 255.0
# Normal blend mode
blended = img1 * (1 - blend_factor) + img2 * blend_factor
# Convert back to uint8
blended = (blended * 255.0).clip(0, 255).astype(np.uint8)
return blended
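# Usage sketch: a 75%/25% mix of two same-sized uint8 RGB arrays:
#   mixed = blend_numpy_images(base, overlay, blend_factor=0.25)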
def process_normal_map(image):
"""
Process image through NormalMapSimple and ConvertNormals
"""
# Convert numpy image to torch tensor with batch dimension
image_tensor = torch.from_numpy(image).unsqueeze(0).float() / 255.0
# Create instances of the classes
normal_map_generator = NormalMapSimple()
normal_converter = ConvertNormals()
# Generate initial normal map
normal_map = normal_map_generator.normal_map(image_tensor, scale_XY=1.0)[0]
    # Run the converter Standard -> Standard: only fix_black and the
    # renormalization step take effect
converted_normal = normal_converter.convert_normals(
normal_map,
input_mode="Standard",
output_mode="Standard",
scale_XY=1.0,
normalize=True,
fix_black=True
)[0]
# Convert back to numpy array
result = (converted_normal.squeeze(0).numpy() * 255).astype(np.uint8)
return result
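# Usage sketch (hypothetical path): full image -> normal-map pipeline on a
# uint8 RGB array:
#   rgb = np.array(Image.open("files/images/example.jpg").convert("RGB"))
#   normals = process_normal_map(rgb)  # uint8 RGB normal map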
def infer(path_input, seed=None):
name_base, name_ext = os.path.splitext(os.path.basename(path_input))
_, output_d = lotus(path_input, 'depth', seed, device)
    # Soften the raw depth prediction slightly before deriving normals
    output_d = apply_gaussian_blur(output_d, radius=0.75)
# Convert depth to numpy for normal map processing
depth_array = np.array(output_d)
    # Load the original image for intensity blending, forcing RGB so that
    # grayscale or RGBA inputs do not break the HSV conversion
    input_image = Image.open(path_input).convert("RGB")
    input_array = np.array(input_image)
# Get intensity map from original image
intensity_map = get_image_intensity(input_array, gamma_correction=1.0)
# Resize intensity_map to match depth_array dimensions
depth_height, depth_width = depth_array.shape[:2]
if intensity_map.shape[:2] != (depth_height, depth_width):
intensity_map = cv2.resize(intensity_map, (depth_width, depth_height), interpolation=cv2.INTER_LINEAR)
    # Ensure the depth map has three channels, then blend in a little of the
    # original image's intensity to restore fine surface detail (channel order
    # is irrelevant downstream, where only the per-pixel channel mean is used)
    if depth_array.ndim == 2:
        depth_array = cv2.cvtColor(depth_array, cv2.COLOR_GRAY2RGB)
    blended_result = blend_numpy_images(
        depth_array,
        intensity_map,
        blend_factor=0.15,
        mode="normal",
    )
# Generate normal map from blended result
normal_map = process_normal_map(blended_result)
    os.makedirs("files/output", exist_ok=True)
d_save_path = os.path.join("files/output", f"{name_base}_d{name_ext}")
n_save_path = os.path.join("files/output", f"{name_base}_n{name_ext}")
output_d.save(d_save_path)
Image.fromarray(normal_map).save(n_save_path)
return [path_input, d_save_path], [path_input, n_save_path]
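# Usage sketch (hypothetical input): running the image pipeline directly,
# outside Gradio, returns (input, depth) and (input, normal) slider pairs:
#   depth_pair, normal_pair = infer("files/images/example.jpg", seed=42)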
def infer_video(path_input, seed=None):
_, frames_d, fps = lotus_video(path_input, 'depth', seed, device)
# Apply Gaussian blur to each frame
blurred_frames = []
for frame in frames_d:
# Convert numpy array to PIL Image if needed
if isinstance(frame, np.ndarray):
frame_pil = Image.fromarray(frame)
else:
frame_pil = frame
# Apply blur and convert back to numpy array
blurred_frame = apply_gaussian_blur(frame_pil, radius=0.75)
blurred_frames.append(np.array(blurred_frame))
    os.makedirs("files/output", exist_ok=True)
name_base, _ = os.path.splitext(os.path.basename(path_input))
d_save_path = os.path.join("files/output", f"{name_base}_d.mp4")
imageio.mimsave(d_save_path, blurred_frames, fps=fps)
return d_save_path
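# Usage sketch (hypothetical input): the video pipeline returns the path of
# the rendered depth MP4:
#   depth_video_path = infer_video("files/videos/example.mp4", seed=42)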
def run_demo_server():
    # Wrap the inference functions so they run on ZeroGPU-allocated hardware;
    # the original functools.partial() wrappers bound no arguments, so the
    # functions can be passed to spaces.GPU directly.
    infer_gpu = spaces.GPU(infer)
    infer_video_gpu = spaces.GPU(infer_video)
gradio_theme = gr.themes.Default()
with gr.Blocks(
theme=gradio_theme,
title="LOTUS (Depth & Normal Maps - Discriminative)",
css="""
#download {
height: 118px;
}
.slider .inner {
width: 5px;
background: #FFF;
}
.viewport {
aspect-ratio: 4/3;
}
.tabs button.selected {
font-size: 20px !important;
color: crimson !important;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
.md_feedback li {
margin-bottom: 0px !important;
}
""",
head="""
""",
) as demo:
with gr.Tabs(elem_classes=["tabs"]):
with gr.Tab("IMAGE"):
with gr.Row():
with gr.Column():
image_input = gr.Image(
label="Input Image",
type="filepath",
)
with gr.Row():
image_submit_btn = gr.Button(
value="Predict Depth!", variant="primary"
)
image_reset_btn = gr.Button(value="Reset")
with gr.Column():
image_output_d = ImageSlider(
label="Depth Output (Discriminative)",
type="filepath",
interactive=False,
elem_classes="slider",
position=0.25,
)
image_output_n = ImageSlider(
label="OpenGL Normal Map Output",
type="filepath",
interactive=False,
elem_classes="slider",
position=0.25,
)
gr.Examples(
fn=infer_gpu,
examples=sorted([
[os.path.join("files", "images", name)]
for name in os.listdir(os.path.join("files", "images"))
]),
inputs=[image_input],
outputs=[image_output_d, image_output_n],
cache_examples=False,
)
with gr.Tab("VIDEO"):
with gr.Row():
with gr.Column():
input_video = gr.Video(
label="Input Video",
autoplay=True,
loop=True,
)
with gr.Row():
video_submit_btn = gr.Button(
value="Predict Depth!", variant="primary"
)
video_reset_btn = gr.Button(value="Reset")
with gr.Column():
video_output_d = gr.Video(
label="Depth Output (Discriminative)",
interactive=False,
autoplay=True,
loop=True,
show_share_button=True,
)
gr.Examples(
fn=infer_video_gpu,
examples=sorted([
[os.path.join("files", "videos", name)]
for name in os.listdir(os.path.join("files", "videos"))
]),
inputs=[input_video],
outputs=[video_output_d],
cache_examples=False,
)
### Image
image_submit_btn.click(
fn=infer_gpu,
inputs=[image_input],
outputs=[image_output_d, image_output_n],
concurrency_limit=1,
)
image_reset_btn.click(
fn=lambda: [None, None],
inputs=[],
outputs=[image_output_d, image_output_n],
queue=False,
)
### Video
video_submit_btn.click(
fn=infer_video_gpu,
inputs=[input_video],
outputs=[video_output_d],
queue=True,
)
video_reset_btn.click(
fn=lambda: None,
inputs=[],
outputs=[video_output_d],
)
### Server launch
demo.queue(
api_open=False,
).launch(
server_name="0.0.0.0",
server_port=7860,
)
def main():
    # Log the installed package versions for easier debugging
    os.system("pip freeze")
    # Start each run with a clean output directory
    shutil.rmtree("files/output", ignore_errors=True)
    run_demo_server()
if __name__ == "__main__":
main()