Space status: Build error
| import streamlit as st | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| import torch | |
# Load the pre-trained model and tokenizer once per server process.
# Streamlit re-executes this whole script on every widget interaction, so
# without caching the weights would be reloaded on every button click.
# NOTE(review): "distilbert-base-uncased" is a base (not fine-tuned)
# checkpoint, so the sequence-classification head is presumably freshly
# initialised with the default 2 labels — its outputs are not meaningful
# song matches until a fine-tuned checkpoint is used. Confirm intent.
model_name = "distilbert-base-uncased"


@st.cache_resource
def _load_model_and_tokenizer(name):
    """Load and return ``(model, tokenizer)`` for the checkpoint *name*.

    Decorated with ``st.cache_resource`` so the objects are created once and
    reused across script reruns instead of being reloaded each time.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained(name)
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(model_name)
# Define a function to search for Samba songs
def search_samba(query, k=5):
    """Return the top-*k* predicted class indices and probabilities for *query*.

    Args:
        query: Free-text song title or artist entered by the user.
        k: Maximum number of results to return. It is clamped to the number
            of classes the model outputs, because ``torch.topk`` raises when
            ``k`` exceeds the size of the dimension it selects from.

    Returns:
        An ``(indices, values)`` pair of tensors with a leading batch
        dimension of 1 (a single query is tokenized), sorted by descending
        probability.

    Raises:
        ValueError: If ``k`` is less than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")
    # Truncate so queries longer than the model's maximum length don't raise.
    inputs = tokenizer(query, return_tensors="pt", truncation=True)
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # A base checkpoint's classification head may expose fewer than k classes.
    k = min(k, probs.size(-1))
    top = torch.topk(probs, k=k)
    return top.indices, top.values
# Create a Streamlit app: title, query box and results list.
st.title("Samba Search")
query = st.text_input("Enter a song title or artist")
if st.button("Search"):
    if not (query and query.strip()):
        # Nothing meaningful to classify; ask for input instead of querying the model.
        st.warning("Please enter a song title or artist.")
    else:
        indices, values = search_samba(query)
        # Only one query was tokenized, so the results live in batch row 0.
        top_indices, top_values = indices[0], values[0]
        # Report the real number of matches; the model may expose fewer than 5 classes.
        st.write(f"Top {len(top_indices)} Samba Song Matches:")
        for rank, (index, value) in enumerate(zip(top_indices, top_values), start=1):
            st.write(f"{rank}. Song {index.item()} - Probability: {value.item():.2f}")