from pydantic import BaseModel
from promptSearchEngine import PromptSearchEngine
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import streamlit as st

EMBEDDING_MODEL = "all-MiniLM-L6-v2"
DATASET = "Gustavosta/Stable-Diffusion-Prompts"


class SearchRequest(BaseModel):
    query: str
    n: int | None = 5


# model = SentenceTransformer("all-MiniLM-L6-v2")
# dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts", split="test[:1%]")
# promptSearchEngine = PromptSearchEngine(dataset["Prompt"], model)
@st.cache_resource
def load_model():
    """Initialize the pretrained model for vectorizing.

    The @st.cache_resource annotation enables caching in Streamlit.
    """
    return SentenceTransformer(EMBEDDING_MODEL)
@st.cache_resource
def load_dataSet():
    """Load the prompts dataset.

    The @st.cache_resource annotation enables caching in Streamlit.
    """
    return load_dataset(DATASET, split="test[:1%]")
@st.cache_resource
def load_searchEngine(prompts, _model):
    """Initialize the search engine and vectorize the raw prompts from the dataset.

    The @st.cache_resource annotation enables caching in Streamlit; the leading
    underscore in _model tells Streamlit not to hash that argument.

    Args:
        prompts: The sequence of raw prompts from the dataset.
        _model: The model for vectorizing.
    """
    return PromptSearchEngine(prompts, _model)
model = load_model()
dataset = load_dataSet()
promptSearchEngine = load_searchEngine(dataset["Prompt"], model)
with st.form("search_form"):
    st.write("Prompt Search Engine")
    query = st.text_area("Prompt to search")
    number = st.number_input("Number of similar prompts", value=5, min_value=0, max_value=100)
    submitted = st.form_submit_button("Submit")
    if submitted:
        result = promptSearchEngine.most_similar(query, number)
        st.dataframe(
            result,
            use_container_width=True,
            column_config={
                1: st.column_config.NumberColumn(
                    "Similarity",
                    help="Range in [-1, 1] where 1 is maximum similarity, i.e. the prompts are identical.",
                    format="%.4f",
                ),
                2: st.column_config.TextColumn("Prompts", help="The similar prompts"),
            },
        )
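
# ---------------------------------------------------------------------------
# Illustrative sketch only: promptSearchEngine.py is not shown in this Space,
# so everything below is an assumption about the PromptSearchEngine interface
# used above. It presumes the engine embeds the corpus once with the given
# SentenceTransformer and ranks prompts by cosine similarity, returning
# (score, prompt) pairs that st.dataframe renders. Names are hypothetical.
#
# from sentence_transformers import util
#
# class PromptSearchEngine:
#     def __init__(self, prompts, model):
#         self.model = model
#         self.prompts = list(prompts)
#         # Embed the whole corpus once at startup.
#         self.corpus_embeddings = model.encode(self.prompts, convert_to_tensor=True)
#
#     def most_similar(self, query, n=5):
#         query_embedding = self.model.encode(query, convert_to_tensor=True)
#         # Cosine-similarity search of the query against the corpus.
#         hits = util.semantic_search(query_embedding, self.corpus_embeddings, top_k=n)[0]
#         return [(hit["score"], self.prompts[hit["corpus_id"]]) for hit in hits]
# ---------------------------------------------------------------------------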