Spaces:
Runtime error
Runtime error
| import os, subprocess, shlex, sys, gc | |
| import time | |
| import torch | |
| import numpy as np | |
| import shutil | |
| import argparse | |
| import gradio as gr | |
| import uuid | |
| import spaces | |
| from huggingface_hub import snapshot_download | |
| # | |
# Prebuilt extension wheels shipped in the local `wheel/` directory. They are
# installed at startup, one pip invocation per wheel, in this order.
_WHEEL_FILES = (
    "torch_scatter-2.1.2+pt21cu121-cp310-cp310-linux_x86_64.whl",
    "flash_attn-2.6.3+cu123torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
    "diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
    "simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl",
    "curope-0.0.0-cp310-cp310-linux_x86_64.whl",
    "pointops-1.0-cp310-cp310-linux_x86_64.whl",
)
for _wheel_file in _WHEEL_FILES:
    # Invoke pip through the running interpreter so the wheel lands in the
    # environment this script executes in, not whichever `pip` is on PATH.
    _pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", f"wheel/{_wheel_file}"]
    )
    if _pip_result.returncode != 0:
        # Installation stays best-effort (no exception), but a failure is
        # reported instead of being silently ignored.
        print(
            f"WARNING: pip install failed for wheel/{_wheel_file} "
            f"(exit code {_pip_result.returncode})",
            file=sys.stderr,
        )
| from src.utils.visualization_utils import render_video_from_file | |
| from src.model import LSM_MASt3R | |
# Checkpoints live in the Hugging Face Hub repo under `remote_dir`; the app
# reads them from `local_dir`, in one case under a different file name.
repo_id = "Journey9ni/LSM"
remote_dir = "checkpoints/pretrained_models"
local_dir = "checkpoints/pretrained_model"

# File name inside the Hub repo -> file name used locally.
model_path_map = {
    "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth": "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
    "checkpoint-40.pth": "checkpoint-40.pth",
    "demo_e200.ckpt": "lang_seg.ckpt",
}

os.makedirs(local_dir, exist_ok=True)

# Download the repo snapshot into the current working directory.
snapshot_download(repo_id=repo_id, local_dir='./')

# Move every downloaded checkpoint to its local location and name.
for hub_name, target_name in model_path_map.items():
    downloaded_file = os.path.join(remote_dir, hub_name)
    target_file = os.path.join(local_dir, target_name)
    os.rename(downloaded_file, target_file)

# Load the checkpoint on the GPU and call .eval() on the resulting model.
model_path = "checkpoints/pretrained_model/checkpoint-40.pth"
model = LSM_MASt3R.from_pretrained(model_path, device='cuda').eval()
def process(inputfiles, input_path=None):
    """Run the reconstruction on two images and return the UI output values.

    Args:
        inputfiles: Files chosen in the upload widget, or None when nothing
            was uploaded. Ignored when `input_path` is given.
        input_path: Optional name of a folder under ``./assets/examples/``;
            its entries (sorted by name) are used as the inputs instead.

    Returns:
        A 6-tuple ``(filelist, rgb_video_path, depth_video_path,
        feature_video_path, ply_path, ply_path)``. When the input is not
        exactly two files, a warning is shown and six ``None`` values are
        returned.
    """
    if input_path is not None:
        imgs_path = './assets/examples/' + input_path
        inputfiles = [
            os.path.join(imgs_path, imgs_name)
            for imgs_name in sorted(os.listdir(imgs_path))
        ]
        for file_path in inputfiles:
            print(file_path)
        print(inputfiles)

    filelist = inputfiles
    # `inputfiles` is None when the button is pressed without an upload;
    # treat that like a wrong file count instead of letting len() raise.
    if not filelist or len(filelist) != 2:
        gr.Warning("Please select 2 images")
        return None, None, None, None, None, None

    # Unique per-request output directory, created only for valid input so
    # nothing has to be cleaned up on the rejection path above.
    cache_dir = os.path.join('outputs', str(uuid.uuid4()))
    os.makedirs(cache_dir, exist_ok=True)

    try:
        render_video_from_file(filelist, model, output_path=cache_dir, resolution=512)
    except Exception:
        # Do not leave a half-filled directory behind when rendering fails.
        shutil.rmtree(cache_dir, ignore_errors=True)
        raise

    # NOTE(review): these paths assume render_video_from_file writes the PLY
    # into cache_dir and the videos into cache_dir/moved - confirm against
    # src.utils.visualization_utils.
    ply_path = os.path.join(cache_dir, 'gaussians.ply')
    rgb_video_path = os.path.join(cache_dir, 'moved', 'output_images_video.mp4')
    depth_video_path = os.path.join(cache_dir, 'moved', 'output_depth_video.mp4')
    feature_video_path = os.path.join(cache_dir, 'moved', 'output_fmap_video.mp4')
    # The PLY path is returned twice: once for the 3D viewer, once for download.
    return filelist, rgb_video_path, depth_video_path, feature_video_path, ply_path, ply_path
| _TITLE = 'LargeSpatialModel' | |
| _DESCRIPTION = ''' | |
| <div style="display: flex; justify-content: center; align-items: center;"> | |
| <div style="width: 100%; text-align: center; font-size: 30px;"> | |
| <strong>Large Spatial Model: End-to-end Unposed Images to Semantic 3D</strong> | |
| </div> | |
| </div> | |
| <p></p> | |
| <div align="center"> | |
| <a style="display:inline-block" href="https://arxiv.org/abs/2410.18956"><img src="https://img.shields.io/badge/ArXiv-2410.18956-b31b1b?logo=arxiv" alt='arxiv'></a> | |
| <a style="display:inline-block" href="https://largespatialmodel.github.io/"><img src='https://img.shields.io/badge/Project_Page-ff7512?logo=lightning'></a> | |
| <a title="Social" href="https://x.com/WayneINR" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> | |
| <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social"> | |
| </a> | |
| </div> | |
| <p></p> | |
| * Official demo of: [LargeSpatialModel: End-to-end Unposed Images to Semantic 3D](https://largespatialmodel.github.io/). | |
| * Examples for direct viewing: you can simply click the examples (in the bottom of the page), to quickly view the results on representative data. | |
| ''' | |
# Demo page: an input panel, an output panel, a clickable example and the
# button that triggers `process`. The nesting below follows the order of the
# layout statements.
block = gr.Blocks()
block.queue()

with block:
    gr.Markdown(_DESCRIPTION)

    # ---- Input panel: file upload next to a gallery of the chosen images ----
    with gr.Column(variant="panel"):
        with gr.Tab("Input"):
            with gr.Row():
                with gr.Column(scale=1):
                    inputfiles = gr.File(file_count="multiple", label="Load Images")
                    # Hidden field that carries the example folder name.
                    input_path = gr.Textbox(visible=False, label="example_path")
                with gr.Column(scale=1):
                    image_gallery = gr.Gallery(
                        label="Gallery",
                        show_label=False,
                        elem_id="gallery",
                        columns=[2],
                        height=300,  # Fixed height
                        object_fit="cover",  # Ensure images fill the space
                    )
            button_gen = gr.Button("Start Reconstruction", elem_id="button_gen")
            processing_msg = gr.Markdown("Processing...", visible=False, elem_id="processing_msg")

    # ---- Output panel: three videos, the 3D viewer and the PLY download ----
    with gr.Column(variant="panel"):
        with gr.Tab("Output"):
            with gr.Row():
                with gr.Column(scale=1):
                    rgb_video = gr.Video(label="RGB Video", autoplay=True)
                with gr.Column(scale=1):
                    feature_video = gr.Video(label="Feature Video", autoplay=True)
                with gr.Column(scale=1):
                    depth_video = gr.Video(label="Depth Video", autoplay=True)
            with gr.Row():
                with gr.Group():
                    output_model = gr.Model3D(
                        label="3D Dense Model under Gaussian Splats Formats, need more time to visualize",
                        interactive=False,
                        camera_position=[0.5, 0.5, 1],  # Slight offset for better model viewing
                        height=600,
                    )
                    gr.Markdown(
                        """
<div class="model-description">
Use the left mouse button to rotate, the scroll wheel to zoom, and the right mouse button to move.
</div>
"""
                    )
            with gr.Row():
                output_file = gr.File(label="PLY File")

    def _run_example(example_name):
        # An example supplies only a folder name; there are no uploaded files.
        return process(inputfiles=None, input_path=example_name)

    # Output components are listed in the order of `process`'s return tuple.
    examples = gr.Examples(
        examples=[
            "sofa",
        ],
        inputs=[input_path],
        outputs=[image_gallery, rgb_video, depth_video, feature_video, output_model, output_file],
        fn=_run_example,
        cache_examples=True,
        label="Examples",
    )

    button_gen.click(
        process,
        inputs=[inputfiles],
        outputs=[image_gallery, rgb_video, depth_video, feature_video, output_model, output_file],
    )

block.launch(server_name="0.0.0.0", share=False)