Spaces:
Sleeping
Sleeping
| import random | |
| import pandas as pd | |
| import gradio as gr | |
| from typing import Dict, Optional | |
| import unibox as ub | |
# Module-level cache of the dataset currently in memory, so it survives
# between Gradio callbacks. "id" is the dataset id that was last loaded and
# "df" is its DataFrame; both are None until the first load.
CURRENT_DATASET = dict(id=None, df=None)

# Single-letter rating codes (values of the row's "rating" field) mapped to
# the words emitted in captions.
rating_map = dict(
    g="general",
    s="sensitive",
    q="questionable",
    e="explicit",
)
def load_dataset_if_needed(dataset_id: str):
    """Make CURRENT_DATASET hold the dataset identified by *dataset_id*.

    When the cached id already equals *dataset_id* nothing happens. Otherwise
    ``hf://<dataset_id>`` is loaded through unibox, converted to a pandas
    DataFrame, and both the id and the DataFrame are stored in
    CURRENT_DATASET. Loader errors propagate to the caller unchanged.
    """
    if CURRENT_DATASET["id"] == dataset_id:
        # Cache hit: the requested dataset is already in memory.
        return

    fresh_df = ub.loads(f"hf://{dataset_id}").to_pandas()
    # Only record the new id after a successful load so a failure does not
    # leave the cache claiming a dataset it does not hold.
    CURRENT_DATASET["id"] = dataset_id
    CURRENT_DATASET["df"] = fresh_df
def convert_dbr_tag_string(tag_string: str, shuffle: bool = True,
                           rng: Optional[random.Random] = None) -> str:
    """Turn a whitespace-separated Danbooru tag string into a caption fragment.

    Each tag has its underscores replaced by spaces, and the tags are joined
    with ", "::

        "1girl long_hair blush" -> "1girl, long hair, blush"   (shuffle=False)

    Args:
        tag_string: Tags separated by whitespace. Runs of spaces, tabs or
            newlines all count as one separator; blank input yields "".
        shuffle: When true, randomly reorder the tags before joining.
        rng: Optional ``random.Random`` instance to shuffle with (for
            reproducible captions). Defaults to the global ``random`` module
            generator, matching the previous behavior.

    Returns:
        The comma-separated tag list.
    """
    # str.split() with no argument splits on any whitespace and drops empty
    # pieces, so "a  b\tc" yields three tags rather than "b\tc" as one tag.
    tags_list = [tag.replace("_", " ") for tag in tag_string.split()]
    if shuffle:
        (rng if rng is not None else random).shuffle(tags_list)
    return ", ".join(tags_list)
def _has_value(value) -> bool:
    """Return True when a DataFrame cell holds a usable value.

    pandas represents missing cells as None, NaN or pd.NA. NaN is truthy and
    ``bool(pd.NA)`` raises, so a plain ``if value:`` test is not enough.
    Numbers (including 0) count as present; other values fall back to
    truthiness, so "" is treated as missing.
    """
    if value is None:
        return False
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return False
    if pd.api.types.is_number(value):
        return True  # e.g. a score of 0 is a real value, not "missing"
    return bool(value)


def get_tags_dict(df_row: pd.Series) -> dict:
    """Extract the caption components of one dataset row.

    Reads the row's ``rating``, ``tag_string_artist``,
    ``tag_string_character``, ``tag_string_copyright``,
    ``tag_string_general``, ``tag_string_meta`` and ``score`` fields.

    Returns:
        A dict with keys ``rating_str`` (rating code mapped via
        ``rating_map``), ``artist_str`` (raw artist string),
        ``character_str`` / ``general_str`` / ``meta_str`` (converted with
        ``convert_dbr_tag_string``), ``copyright_str`` (``"copyright:"`` +
        raw string) and ``score`` (``str(score)``). A component whose source
        cell is missing (None / NaN / pd.NA) or empty becomes "".
    """
    rating = df_row["rating"]
    artist = df_row["tag_string_artist"]
    character = df_row["tag_string_character"]
    copyright_ = df_row["tag_string_copyright"]
    general = df_row["tag_string_general"]
    meta = df_row["tag_string_meta"]
    score = df_row["score"]

    rating_str = rating_map.get(rating, "") if _has_value(rating) else ""
    artist_str = artist if _has_value(artist) else ""
    character_str = convert_dbr_tag_string(character) if _has_value(character) else ""
    copyright_str = f"copyright:{copyright_}" if _has_value(copyright_) else ""
    general_str = convert_dbr_tag_string(general) if _has_value(general) else ""
    meta_str = convert_dbr_tag_string(meta) if _has_value(meta) else ""
    # Keep a score of 0 (previously dropped by a truthiness test) but not NaN.
    _score = str(score) if _has_value(score) else ""

    return {
        "rating_str": rating_str,
        "artist_str": artist_str,
        "character_str": character_str,
        "copyright_str": copyright_str,
        "general_str": general_str,
        "meta_str": meta_str,
        "score": _score,
    }
def build_tags_from_tags_dict(tags_dict: dict, add_artist_tags: bool = True) -> str:
    """Join the non-empty caption parts of *tags_dict* with ", ".

    The order is rating, artist (rendered as ``artist:<name>`` and only
    included when *add_artist_tags* is true), character, copyright, general.
    The ``meta_str`` and ``score`` entries are not used.
    """
    artist = tags_dict["artist_str"]
    ordered_parts = [
        tags_dict["rating_str"],
        f"artist:{artist}" if (add_artist_tags and artist) else "",
        tags_dict["character_str"],
        tags_dict["copyright_str"],
        tags_dict["general_str"],
    ]
    # Empty components are skipped so no stray ", , " appears.
    return ", ".join(part for part in ordered_parts if part)
def get_captions_for_rows(df, start_idx: int = 0, end_idx: int = 5,
                          tags_front: str = "", tags_back: str = "",
                          add_artist_tags: bool = True) -> list:
    """Build one caption string per row in the positional slice start_idx:end_idx.

    Each caption is *tags_front*, the tags built from the row, and *tags_back*
    joined with ", ", skipping any of the three that is empty.
    """
    def _caption_for(row) -> str:
        body = build_tags_from_tags_dict(get_tags_dict(row), add_artist_tags)
        return ", ".join(filter(None, (tags_front, body, tags_back)))

    window = df.iloc[start_idx:end_idx]
    return [_caption_for(row) for _, row in window.iterrows()]
def get_previews_for_rows(df: pd.DataFrame, start_idx: int = 0, end_idx: int = 5) -> list:
    """Return the ``large_file_url`` values of rows start_idx:end_idx (by position)."""
    window = df.iloc[start_idx:end_idx]
    if len(window) == 0:
        # An empty window yields no URLs, even if the column is absent.
        return []
    return window["large_file_url"].tolist()
def gradio_interface(
    dataset_id: str,
    start_idx: int = 0,
    display_count: int = 5,
    tags_front: str = "",
    tags_back: str = "",
    add_artist_tags: bool = True
):
    """Build the caption table, preview gallery and info text for one request.

    Steps:
      1) Load the dataset if needed (cached by id in CURRENT_DATASET).
      2) Clamp the requested window to the dataset bounds.
      3) Return (DataFrame, Gallery, InfoMessage).

    Args:
        dataset_id: Dataset id handed to ``load_dataset_if_needed``.
        start_idx: First row position to show; clamped to ``[0, len - 1]``.
        display_count: Number of rows to show from ``start_idx``.
        tags_front: Text prepended to every caption.
        tags_back: Text appended to every caption.
        add_artist_tags: Whether captions include the ``artist:`` tag.

    Returns:
        ``(df_out, previews, info_msg)``: a DataFrame with ``index`` and
        ``Captions`` columns, the preview URL list and a Markdown summary.
        If loading fails or the dataset is empty, an empty DataFrame, an
        empty list and an error/info message are returned instead.
    """
    # 1) Possibly reload. Loader failures (bad id, network, ...) are reported
    # in the UI instead of escaping the event handler.
    try:
        load_dataset_if_needed(dataset_id)
    except Exception as err:  # top-level UI boundary: surface any loader error
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id}: {err}"
    dset_df = CURRENT_DATASET["df"]
    if dset_df is None:
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id}"

    # 2) Figure out total length, clamp inputs
    total_len = len(dset_df)
    if total_len == 0:
        return pd.DataFrame(), [], f"Dataset {dataset_id} is empty."
    # Gradio Number/Slider values can arrive as floats (or None when the
    # field is cleared); positional iloc slicing and range() need ints.
    start_idx = int(start_idx) if start_idx is not None else 0
    display_count = max(int(display_count), 0) if display_count is not None else 0
    start_idx = min(max(start_idx, 0), total_len - 1)
    end_idx = min(start_idx + display_count, total_len)

    # 3) Build results
    captions = get_captions_for_rows(dset_df, start_idx, end_idx, tags_front, tags_back, add_artist_tags)
    previews = get_previews_for_rows(dset_df, start_idx, end_idx)
    df_out = pd.DataFrame({"index": range(start_idx, end_idx), "Captions": captions})

    # 4) Build info string
    info_msg = (
        f"**Current dataset:** {CURRENT_DATASET['id']}  \n"
        f"**Dataset length:** {total_len}  \n"
        f"**start_idx:** {start_idx}, **display_count:** {display_count}, "
        f"**tags_front:** '{tags_front}', **tags_back:** '{tags_back}', "
        f"**add_artist_tags:** {add_artist_tags}"
    )
    return df_out, previews, info_msg
with gr.Blocks() as demo:
    gr.Markdown("## Danbooru2025 Dataset Captions and Previews")
    with gr.Row():
        with gr.Column(scale=1):
            dataset_id_input = gr.Textbox(
                value="dataproc5/test-danbooru2025-tag-balanced-2k",
                label="Dataset ID",
            )
            # precision=0 makes Gradio round and deliver an int; the default
            # (precision=None) yields a float, which breaks positional
            # slicing and range() in gradio_interface.
            start_idx_input = gr.Number(value=500, label="Start Index", precision=0)
            display_count_input = gr.Slider(
                value=5, minimum=1, maximum=50, step=1,
                label="Number of Items",
            )
            tags_front_input = gr.Textbox(value="", label="Tags Front")
            tags_back_input = gr.Textbox(value="", label="Tags Back")
            add_artist_tags_input = gr.Checkbox(label="Add artist tags", value=True)
            run_button = gr.Button("Get Captions & Previews")
        with gr.Column(scale=2):
            captions_df_out = gr.DataFrame(label="Captions")
            previews_gallery_out = gr.Gallery(label="Previews", type="filepath")
            info_textbox_out = gr.Markdown(value="")

    # Event listeners must be registered inside the Blocks context.
    run_button.click(
        fn=gradio_interface,
        inputs=[
            dataset_id_input,
            start_idx_input,
            display_count_input,
            tags_front_input,
            tags_back_input,
            add_artist_tags_input,
        ],
        outputs=[
            captions_df_out,
            previews_gallery_out,
            info_textbox_out,
        ],
    )

# Launch only when run as a script, so importing this module (e.g. for tests
# or Gradio's reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()