| from autogen import GroupChatManager |
| import json |
| import re, os |
| import networkx as nx |
|
|
| from agents import create_parse_agents, create_graph_agents, language_summary_agents, calculation_summary_agents |
| from agents import is_termination_msg, is_termination_require, gpt4_config |
| from corrector_agents import get_corrector_agents |
| from refiner_agents import get_refiner_agents |
|
|
| from chats import InputParserGroupChat, RequirementGroupChat, LanguageGroupChat, CalculationGroupChat, SceneGraphGroupChat, SchemaGroupChat, LayoutCorrectorGroupChat, ObjectDeletionGroupChat, LayoutRefinerGroupChat |
|
|
| from utils import get_room_priors, extract_list_from_json |
| from utils import preprocess_scene_graph, build_graph, remove_unnecessary_edges, handle_under_prepositions, get_conflicts, get_size_conflicts, get_object_from_scene_graph |
| from utils import get_object_from_scene_graph, get_rotation, get_cluster_objects, clean_and_extract_edges |
| from utils import get_cluster_size |
| from utils import get_possible_positions, is_point_bbox, calculate_overlap, get_topological_ordering, place_object, get_depth, get_visualization |
| import openshape |
| import torch |
| import numpy as np |
| import transformers |
| import threading |
| import multiprocessing |
| import sys, shutil |
| import pandas as pd |
| from torch.nn import functional as F |
| import objaverse |
| import trimesh |
| import certifi |
| import ssl |
|
|
# SECURITY(review): this disables TLS certificate verification for EVERY
# https request made by this process (presumably to work around failing
# Objaverse/HF downloads) — it permits man-in-the-middle attacks. Confirm
# whether it can be removed once the CA bundle below is in place.
ssl._create_default_https_context = ssl._create_unverified_context
# Point OpenSSL at certifi's CA bundle. NOTE(review): this contradicts the
# line above (verification is already disabled); keep only one workaround.
os.environ['SSL_CERT_FILE'] = certifi.where()
|
|
| class Generator: |
| def __init__(self, layout_elements=['south_wall', 'north_wall', 'west_wall', 'east_wall', 'middle of the room', 'ceiling'], room_dimensions=[5.0, 5.0, 3.0], result_file="./results/layout_w_cot.json"): |
| |
| self.room_dimensions = room_dimensions |
| self.room_priors = get_room_priors(self.room_dimensions) |
| |
| self.layout_elements = list(layout_elements) |
| self.result_file = result_file |
| self.scene_graph = None |
| self.cot_info = {} |
|
|
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
| meta = json.load( |
| open('./embeddings/objaverse_meta.json') |
| ) |
| self.meta = {x['u']: x for x in meta['entries']} |
|
|
| deser = torch.load('./embeddings/objaverse.pt') |
| self.us = deser['us'] |
| self.feats = deser['feats'] |
|
|
| local_assets = pd.read_excel("./assets/copy.xlsx", skiprows=2) |
| captions = local_assets["caption_clip"].tolist() |
|
|
| file_paths = [] |
| bbx_values = [] |
| for index, row in local_assets.iterrows(): |
| model_name = row['name_en'] |
| model_path = os.path.join("./assets/lvm_2032fbx", f"{model_name}.fbx") |
| file_paths.append(model_path) |
| bbx_values.append(row['bbx']) |
|
|
| self.caption_to_file = [ |
| { |
| "caption": caption, |
| "file_path": path, |
| "bbx": bbx |
| } |
| for caption, path, bbx in zip(captions, file_paths, bbx_values) |
| ] |
| |
|
|
| self.clip_model, self.clip_prep = transformers.CLIPModel.from_pretrained( |
| "./ckpts/CLIP-ViT-bigG-14-laion2B-39B-b160k", |
| low_cpu_mem_usage=True, torch_dtype=torch.float16, |
| offload_state_dict=True, |
| ), transformers.CLIPProcessor.from_pretrained("./ckpts/CLIP-ViT-bigG-14-laion2B-39B-b160k") |
|
|
| self.local_embeddings = torch.load("./embeddings/local.pt") |
|
|
|
|
| def parse_input(self, user_input, max_number_of_objects): |
| self.user_input = user_input |
| self.max_number_of_objects = max_number_of_objects |
| user_proxy, requirements_analyzer, substructure_analyzer, substructure_analyzer_checker, interior_designer, designer_checker = create_parse_agents(self.max_number_of_objects) |
|
|
| init_groupchat = RequirementGroupChat( |
| agents=[user_proxy, requirements_analyzer, substructure_analyzer, interior_designer, designer_checker], |
| messages=[], |
| max_round=16 |
| ) |
|
|
| manager = GroupChatManager(groupchat=init_groupchat, llm_config=gpt4_config, is_termination_msg=is_termination_require) |
|
|
| user_proxy.initiate_chat( |
| manager, |
| message=f""" |
| The room has the size {self.room_dimensions[0]}m x {self.room_dimensions[1]}m x {self.room_dimensions[2]}m |
| User Input (in triple backquotes): |
| ``` |
| {self.user_input} |
| ``` |
| Room layout elements in the room (in triple backquotes): |
| ``` |
| ['south_wall', 'north_wall', 'west_wall', 'east_wall', 'middle of the room', 'ceiling'] |
| ``` |
| json |
| """, |
| ) |
|
|
| |
| |
| |
| |
| self.designer_response = json.loads(init_groupchat.messages[-2]["content"]) |
| self.cot_info["parse_cot"] = self.designer_response["chain_of_thought"] |
| |
| |
|
|
    def retrieve_local_assets(self):
        """Match every designed object to a 3D asset and record its size.

        For each object proposed by the designer, a CLIP text embedding is
        computed and matched against the local asset library first; when no
        sufficiently similar local asset exists, Objaverse is queried as a
        fallback. The chosen asset's bounding box is written back into the
        object's 'bounding_box_size'.
        """
        # Lock stashed on the sys module so unrelated modules can coordinate
        # CLIP device moves without an import cycle.
        print("Locking...")
        sys.clip_move_lock = threading.Lock()
        print("Locked.")

        if torch.cuda.is_available():
            with sys.clip_move_lock:
                self.clip_model.cuda()
        # Retrieval is inference-only; disable autograd for the whole pass.
        torch.set_grad_enabled(False)

        def preprocess(input_string):
            # Strip digits (object ids like "chair 1") and underscores so the
            # query reads as natural language for CLIP.
            wo_numericals = re.sub(r'\d', '', input_string)
            output = wo_numericals.replace("_", " ")
            return output

        def retrieve_local(query_embedding, top=1, sim_th=0.5):
            """Cosine-match the query against the local caption embeddings.

            Returns up to `top` entries with similarity strictly above
            `sim_th`, most similar first.
            """
            query_embedding = F.normalize(query_embedding.detach().cpu(), dim=-1).squeeze()
            sims = []
            # Chunked matmul bounds peak memory for large caption libraries.
            for embedding in torch.split(self.local_embeddings, 10240):
                sims.append(query_embedding @ F.normalize(embedding.float(), dim=-1).T)
            sims = torch.cat(sims)
            sims, indices = torch.sort(sims, descending=True)
            results = []
            for i, sim in zip(indices, sims):
                if sim > sim_th:
                    results.append({
                        "caption": self.caption_to_file[i]["caption"],
                        "file_path": self.caption_to_file[i]["file_path"],
                        "bbx": self.caption_to_file[i]["bbx"],
                        "sim": sim.item()
                    })
                if len(results) >= top:
                    break
            return results

        def retrieve(embedding, top=1, sim_th=0.1, filter_fn=None):
            """Cosine-match the query against the Objaverse embeddings.

            Only uids present in self.meta (and passing filter_fn, when
            given) are returned, most similar first.
            """
            sims = []
            embedding = F.normalize(embedding.detach().cpu(), dim=-1).squeeze()
            for chunk in torch.split(self.feats, 10240):
                sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
            sims = torch.cat(sims)
            sims, idx = torch.sort(sims, descending=True)
            sim_mask = sims > sim_th
            sims = sims[sim_mask]
            idx = idx[sim_mask]
            results = []
            for i, sim in zip(idx, sims):
                if self.us[i] in self.meta:
                    if filter_fn is None or filter_fn(self.meta[self.us[i]]):
                        # NOTE(review): sim here stays a tensor, unlike
                        # retrieve_local which stores sim.item() — confirm
                        # downstream comparisons tolerate both.
                        results.append(dict(self.meta[self.us[i]], sim=sim))
                if len(results) >= top:
                    break
            return results

        def get_filter_fn():
            # With these constants both *_n flags are True, so the filter
            # currently accepts every asset; the bounds are kept in place so
            # face/animation limits can be tightened later.
            face_min = 0
            face_max = 34985808
            anim_min = 0
            anim_max = 563
            anim_n = not (anim_min > 0 or anim_max < 563)
            face_n = not (face_min > 0 or face_max < 34985808)
            filter_fn = lambda x: (
                (anim_n or anim_min <= x['anims'] <= anim_max)
                and (face_n or face_min <= x['faces'] <= face_max)
            )
            return filter_fn

        def get_model_dimensions(file_path):
            """Return (length, width, height) of a mesh file in metres.

            NOTE(review): assumes mesh units are centimetres (hence /100)
            and a Y-up convention (Y extent = height) — confirm for the
            downloaded Objaverse GLBs.
            """
            mesh = trimesh.load(file_path)
            bounding_box = mesh.bounding_box.extents
            length = bounding_box[0] / 100
            width = bounding_box[2] / 100
            height = bounding_box[1] / 100
            return length, width, height

        objects = extract_list_from_json(self.designer_response, 'objects')
        for obj in objects:
            # Retrieval prompt combines object id, material and style.
            text = preprocess("A high-poly " + obj['object_id']) + f" with {obj['material']} material and in {obj['style']} style, high quality"
            device = self.clip_model.device
            tn = self.clip_prep(
                text=[text], return_tensors='pt', truncation=True, max_length=76
            ).to(device)
            enc = self.clip_model.get_text_features(**tn).float().cpu()

            # Try the curated local library first (strict 0.5 threshold).
            retrieved_local = retrieve_local(enc, top=1, sim_th=0.5)
            if retrieved_local:
                retrieved_obj = retrieved_local[0]
                print("Retrieved object: ", retrieved_obj["file_path"])

                source_file = retrieved_obj["file_path"]
                file_extension = os.path.splitext(source_file)[1]

            # NOTE(review): if retrieve_local returned nothing, retrieved_obj
            # is unbound on the first iteration (NameError) or stale from the
            # previous object — confirm an empty local result cannot occur, or
            # guard this branch.
            if retrieved_obj["sim"] > 0.5:
                # Local hit: bbx is stored as a "L,W,H" string in the sheet.
                length, width, height = map(float, retrieved_obj["bbx"].split(','))
                obj['bounding_box_size'] = {'Length': length, 'Width': width, 'Height': height}
            else:
                # Fall back to Objaverse with a much looser threshold.
                # NOTE(review): retrieve(...) may return an empty list, in
                # which case [0] raises IndexError — confirm acceptable.
                retrieved_obj = retrieve(enc, top=1, sim_th=0.1, filter_fn=get_filter_fn())[0]
                print(f"Retrieved object from Objaverse: {retrieved_obj['u']}")
                processes = multiprocessing.cpu_count()
                objaverse_objects = objaverse.load_objects(
                    uids=[retrieved_obj['u']],
                    download_processes=processes
                )

                # load_objects returns {uid: local file path} for one uid.
                for item_id, file_path in objaverse_objects.items():
                    # Only trust the measured size above a minimal similarity.
                    if retrieved_obj["sim"] > 0.18:
                        length, width, height = get_model_dimensions(file_path)
                        obj['bounding_box_size'] = {'Length': length, 'Width': width, 'Height': height}

        self.designer_response['objects'] = objects
        print(self.designer_response)
|
|
    def create_scene_graph(self):
        """Place every designed object into a relational scene graph.

        One group chat per object decides its anchor, facing and placement
        relative to previously placed objects. The accumulated graph is then
        post-processed: spatial conflicts are corrected one at a time by a
        corrector chat, and size conflicts are resolved by deleting the
        offending object together with everything anchored on it.
        """
        cot_data_1 = []  # shared log of graph post-processing decisions
        user_proxy, interior_architect, schema_engineer = create_graph_agents()

        scene_graph_groupchat = SceneGraphGroupChat(
            agents =[user_proxy, interior_architect, schema_engineer],
            messages=[],
            max_round=10
        )

        # cot_data: per-object chains of thought; json_info: accumulated
        # scene graph; json_data: id -> placement dict echoed back into the
        # next prompt as "previously placed objects".
        cot_data, json_info, json_data = {}, {}, {}
        blocks_designer = extract_list_from_json(self.designer_response, 'objects')

        for d_block in blocks_designer:
            object_id = d_block["object_id"]
            prompt = str(d_block)

            # Fresh manager per object; the group chat itself is reset at the
            # bottom of the loop.
            manager_scene_graph = GroupChatManager(groupchat=scene_graph_groupchat,
                                                   llm_config=gpt4_config,
                                                   human_input_mode="NEVER",
                                                   is_termination_msg=is_termination_msg)

            # NOTE(review): this prompt says 'middle of the floor' while the
            # other prompts in this file say 'middle of the room' — confirm
            # the inconsistency is intentional.
            user_proxy.initiate_chat(
                manager_scene_graph,
                message=f"""
                The room has the size {self.room_dimensions[0]}m x {self.room_dimensions[1]}m x {self.room_dimensions[2]}m
                User Input (in triple backquotes):
                ```
                {self.user_input}
                ```
                Room layout elements in the room (in triple backquotes):
                ```
                ['south_wall', 'north_wall', 'west_wall', 'east_wall', 'middle of the floor', 'ceiling']
                ```
                Previously placed objects in the room (in triple backquotes):
                ```
                {json_data}
                ```
                Object to be placed (in triple backticks):
                ```
                {prompt}
                ```
                """,
            )

            # Lazily create the accumulator on the first object.
            if not json_info:
                json_info["objects_in_room"] = []
            # The schema engineer's JSON is the second-to-last message.
            json_info["objects_in_room"] += json.loads(scene_graph_groupchat.messages[-2]["content"])["objects_in_room"]
            object_data = json.loads(scene_graph_groupchat.messages[-2]["content"])["objects_in_room"][0]

            # Drop the id before echoing the object back in later prompts.
            if 'new_object_id' in object_data:
                del object_data['new_object_id']

            json_data[str(object_id)] = object_data

            if str(object_id) not in cot_data:
                cot_data[str(object_id)] = []

            # Agent replies sit at the odd message indices; even indices are
            # the user-proxy turns.
            indices_to_collect = list(range(1, len(scene_graph_groupchat.messages), 2))
            for idx in indices_to_collect:
                cot_data[str(object_id)].append(json.loads(scene_graph_groupchat.messages[idx]["content"])["chain_of_thought"])

            user_proxy.reset(), interior_architect.reset(), schema_engineer.reset(), scene_graph_groupchat.reset()

        self.cot_info["scene_graph_cot"] = cot_data
        self.scene_graph = json_info
        self.conflict_data = []

        # ---- graph construction and spatial-conflict detection ----
        scene_graph = preprocess_scene_graph(json_info["objects_in_room"], cot_data_1)
        G = build_graph(scene_graph)
        G = remove_unnecessary_edges(G, cot_data_1)
        G, scene_graph = handle_under_prepositions(G, scene_graph, cot_data_1)
        conflicts = get_conflicts(G, scene_graph, cot_data_1)

        print("-------------------CONFLICTS-------------------")
        for conflict in conflicts:
            print(conflict)
        print("\n\n")
        self.conflict_data.append(conflicts)

        user_proxy, spatial_corrector_agent, json_schema_debugger, object_deletion_agent = get_corrector_agents()

        # Fix spatial conflicts one at a time until none remain.
        # NOTE(review): no iteration cap — a corrector that keeps producing
        # conflicting placements loops forever; confirm this is acceptable.
        while len(conflicts) > 0:
            spatial_corrector_agent.reset(), json_schema_debugger.reset()
            groupchat = LayoutCorrectorGroupChat(
                agents =[user_proxy, spatial_corrector_agent, json_schema_debugger],
                messages=[],
                max_round=15
            )
            manager = GroupChatManager(groupchat=groupchat, llm_config=gpt4_config, is_termination_msg=is_termination_msg)
            user_proxy.initiate_chat(
                manager,
                message=f"""
                {conflicts[0]}
                """,
            )
            correction = groupchat.messages[-2]
            # Extract the fenced ```json ... ``` block from the reply.
            # NOTE(review): re.search(...) is assumed to match; a reply with
            # no fenced json block raises AttributeError on .group(1).
            pattern = r'```json\s*([^`]+)\s*```'
            match = re.search(pattern, correction["content"], re.DOTALL).group(1)
            correction_json = json.loads(match)
            self.conflict_data.append(correction_json)
            # Patch the corrected object in place, then re-check the graph.
            corr_obj = get_object_from_scene_graph(correction_json["corrected_object"]["new_object_id"], scene_graph)
            corr_obj["is_on_the_floor"] = correction_json["corrected_object"]["is_on_the_floor"]
            corr_obj["facing"] = correction_json["corrected_object"]["facing"]
            corr_obj["placement"] = correction_json["corrected_object"]["placement"]
            G = build_graph(scene_graph)
            conflicts = get_conflicts(G, scene_graph, cot_data_1)

        # ---- size-conflict resolution by deletion ----
        size_conflicts = get_size_conflicts(G, scene_graph, cot_data_1, self.user_input, self.room_priors)

        print("-------------------SIZE CONFLICTS-------------------")
        for conflict in size_conflicts:
            print(conflict)
        print("\n\n")
        self.conflict_data.append(size_conflicts)

        while len(size_conflicts) > 0:
            object_deletion_agent.reset()
            groupchat = ObjectDeletionGroupChat(
                agents =[user_proxy, object_deletion_agent],
                messages=[],
                max_round=2
            )
            manager = GroupChatManager(groupchat=groupchat, llm_config=gpt4_config, is_termination_msg=is_termination_msg)
            user_proxy.initiate_chat(
                manager,
                message=f"""
                {size_conflicts[0]}
                """,
            )
            correction = groupchat.messages[-1]
            correction_json = json.loads(correction["content"])
            object_to_delete = correction_json["object_to_delete"]
            # Delete the object together with everything anchored on it
            # (its descendants in the placement graph).
            descendants = nx.descendants(G, object_to_delete)
            objs_to_delete = descendants.union({object_to_delete})
            print("Objs to Delete: ", objs_to_delete)
            self.conflict_data.append(f"Objs to Delete: {objs_to_delete}")
            scene_graph = [x for x in scene_graph if x["new_object_id"] not in objs_to_delete]
            for obj in objs_to_delete:
                G.remove_node(obj)

            size_conflicts = get_size_conflicts(G, scene_graph, cot_data_1, self.user_input, self.room_priors)

        self.scene_graph["objects_in_room"] = scene_graph
|
|
| def summary_language(self): |
| user_proxy, language_architect = language_summary_agents() |
|
|
| groupchat = LanguageGroupChat( |
| agents=[user_proxy, language_architect], |
| messages=[], |
| max_round=2 |
| ) |
|
|
| manager = GroupChatManager(groupchat=groupchat, llm_config=gpt4_config, is_termination_msg=is_termination_msg) |
|
|
| user_proxy.initiate_chat( |
| manager, |
| message=f""" |
| The room has the size {self.room_dimensions[0]}m x {self.room_dimensions[1]}m x {self.room_dimensions[2]}m |
| User Input (in triple backquotes): |
| ``` |
| **chain of thought for requirements_analyzer, substructure_analyzer and interior_designer** |
| {self.cot_info["parse_cot"]} |
| ``` |
| **chain of thought for object placement** |
| {self.cot_info["scene_graph_cot"]} |
| ``` |
| **conflict data** |
| {self.conflict_data} |
| ``` |
| **scene graph** |
| {self.scene_graph} |
| ``` |
| Room layout elements in the room (in triple backquotes): |
| ``` |
| ['south_wall', 'north_wall', 'west_wall', 'east_wall', 'middle of the room', 'ceiling'] |
| ``` |
| json |
| """, |
| ) |
|
|
| self.language_sum = groupchat.messages[-1]["content"] |
|
|
    def create_layout(self, debug=False):
        """Compute concrete 3D positions for every object in the scene graph.

        Objects are placed in topological order, one depth level at a time;
        when a level cannot be satisfied, placement backtracks one level,
        wiping the affected positions, for at most 20 rounds. The final
        scene graph is written as JSON under ./results.

        Args:
            debug: print per-node placement diagnostics.

        Returns:
            The last error dict when placement fails at depth 1; otherwise
            None (the layout is persisted to self.result_file).
        """
        cot_data = []
        G = build_graph(self.scene_graph["objects_in_room"])
        nodes = G.nodes()

        # Pre-compute each non-layout object's cluster constraint area.
        cot_data.append("Calculate constraint area for non-layout objects only.")
        for node in nodes:
            if node not in self.layout_elements:
                cluster_size, _ = get_cluster_size(node, G, self.scene_graph["objects_in_room"], cot_data)
                node_obj = get_object_from_scene_graph(node, self.scene_graph["objects_in_room"])
                # Re-key relative directions into signed axis extents.
                cluster_size = {"x_neg" : cluster_size["left of"], "x_pos" : cluster_size["right of"], "y_neg" : cluster_size["behind"], "y_pos" : cluster_size["in front"]}
                node_obj["cluster"] = {"constraint_area" : cluster_size}
                cot_data.append(f"The constraint area for {node} is {cluster_size}.")

        # NOTE(review): self.scene_graph changes shape here — from a dict
        # holding "objects_in_room" to a flat list that also includes the
        # room priors. Later methods must expect the list form.
        self.scene_graph = self.scene_graph["objects_in_room"] + self.room_priors

        prior_ids = ["south_wall", "north_wall", "east_wall", "west_wall", "ceiling", "middle of the room"]
        # Tracks objects whose position is already fully determined.
        point_bbox = dict.fromkeys([item["new_object_id"] for item in self.scene_graph], False)

        # First pass: objects whose position constraints intersect to a
        # single point are placed immediately and pinned.
        for item in self.scene_graph:
            if item["new_object_id"] in prior_ids:
                continue
            possible_pos = get_possible_positions(item["new_object_id"], self.scene_graph, self.room_dimensions, cot_data)

            overlap = None
            if len(possible_pos) == 1:
                overlap = possible_pos[0]
            elif len(possible_pos) > 1:
                overlap = possible_pos[0]
                for pos in possible_pos[1:]:
                    overlap = calculate_overlap(overlap, pos)

            if overlap is not None and is_point_bbox(overlap) and len(possible_pos) > 0:
                # Indices 0/2/4 are presumably (x_min, y_min, z_min) of a
                # degenerate point-bbox — confirm against get_possible_positions.
                item["position"] = {"x" : overlap[0], "y" : overlap[2], "z" : overlap[4]}
                point_bbox[item["new_object_id"]] = True

        scene_graph_wo_layout = [item for item in self.scene_graph if item["new_object_id"] not in self.layout_elements]

        # Depth = distance from the room anchors in the placement graph.
        depth_scene_graph = get_depth(scene_graph_wo_layout)
        max_depth = max(depth_scene_graph.values())

        topological_order = get_topological_ordering(scene_graph_wo_layout)
        topological_order = [item for item in topological_order if item not in self.layout_elements]

        # Level-by-level placement with bounded backtracking: `count` caps
        # total rounds at 20 so the decrement/retry cycle cannot spin forever.
        d = 1
        count = 0
        while d <= max_depth and count < 20:
            count += 1
            error_flag = False

            nodes = [node for node in topological_order if depth_scene_graph[node] == d]
            if debug:
                print(f"Nodes at depth {d}: ", nodes)

            errors = {}

            cot_data.append(f"Place objects: {[node for node in nodes]}.")
            for node in nodes:
                if point_bbox[node]:
                    continue  # already pinned by the single-point pass

                obj = next(item for item in scene_graph_wo_layout if item["new_object_id"] == node)
                cot_data.append(f"Place the object {obj['new_object_id']} at the depth {d}.")
                errors = place_object(obj, self.scene_graph, self.room_dimensions, cot_data, errors={}, debug=debug)

                if debug:
                    print(f"Errors for {obj['new_object_id']}: ", errors)

                # On failure: back off one depth level (or give up at depth
                # 1), wipe every non-pinned position at or beyond the new
                # level, and restart the level loop.
                if errors:
                    if d > 1:
                        d -= 1
                        cot_data.append(f"Errors occur for {obj['new_object_id']}: {errors}. Reduce depth to {d}.")
                        if debug:
                            print("Reducing depth to: ", d)
                    else:
                        cot_data.append(f"Errors occur for {obj['new_object_id']} with depth 1: {errors}. The layout creation failed.")
                        print(f"Errors occur for {obj['new_object_id']} with depth 1: {errors}. The layout creation failed.")
                        self.calculation_data = []
                        return errors

                    error_flag = True
                    cot_data.append(f"Delete positions for objects at or beyond the current depth {d} in order to reposition the objects.")
                    for del_item in scene_graph_wo_layout:
                        if depth_scene_graph[del_item["new_object_id"]] >= d:
                            if "position" in del_item.keys() and not point_bbox[del_item["new_object_id"]]:
                                if debug:
                                    print("Deleting position for: ", del_item["new_object_id"])
                                del del_item["position"]
                    errors = {}
                    break

            # Only advance to the next depth when the whole level succeeded.
            if not error_flag:
                d += 1

        cot_data.append("Save the scene graph.")
        self.calculation_data = cot_data
        print(cot_data)
        print("\n")

        # Persist the layout under a filesystem-safe slug of the user prompt
        # (this overrides the result_file passed to __init__).
        os.makedirs("./results", exist_ok=True)
        jsonname = re.sub(r'[^a-zA-Z0-9]', '_', self.user_input) + '.json'
        self.result_file = os.path.join("./results", jsonname)
        with open(self.result_file, "w") as file:
            json.dump(self.scene_graph, file, indent=4)
|
|
| def summary_calculation(self): |
| if self.calculation_data: |
| user_proxy, calculation_architect = calculation_summary_agents() |
| groupchat = CalculationGroupChat( |
| agents=[user_proxy, calculation_architect], |
| messages=[], |
| max_round=2 |
| ) |
| manager = GroupChatManager(groupchat=groupchat, llm_config=gpt4_config, is_termination_msg=is_termination_msg) |
|
|
| user_proxy.initiate_chat( |
| manager, |
| message=f""" |
| The room has the size {self.room_dimensions[0]}m x {self.room_dimensions[1]}m x {self.room_dimensions[2]}m |
| User Input (in triple backquotes): |
| ``` |
| {self.calculation_data} |
| ``` |
| Room layout elements in the room (in triple backquotes): |
| ``` |
| ['south_wall', 'north_wall', 'west_wall', 'east_wall', 'middle of the room', 'ceiling'] |
| ``` |
| json |
| """, |
| ) |
|
|
| self.calculation_sum = groupchat.messages[-1]["content"] |
|
|
| os.makedirs("./Results_data", exist_ok=True) |
| filename = re.sub(r'[^a-zA-Z0-9]', '_', self.user_input) + '.md' |
| full_path = os.path.join("./Results_data", filename) |
| with open(full_path, 'w', encoding='utf-8') as file: |
| file.write(self.language_sum) |
| file.write('\n\n## 6. **Object Placement**\n') |
| file.write(self.calculation_sum) |
| else: |
| pass |
|
|