import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from hf_model import BERT_FFNN, BertFFNNConfig
from torch.nn.functional import sigmoid

# Gradio demo for a multi-label emotion classifier loaded from the Hugging Face Hub.

# Label names; LABELS[i] is reported when output column i exceeds the threshold.
# Assumed to match the order used at training time — TODO confirm against the model card.
LABELS = ["anger", "fear", "joy", "sadness", "surprise"]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Single source of truth for the Hub repository providing config, weights and tokenizer.
MODEL_REPO = "NeuralNest05/emo-detector"
# Default tokenizer truncation length; longer inputs are cut to this many tokens.
MAX_LENGTH = 128

config = BertFFNNConfig.from_pretrained(MODEL_REPO)
model = BERT_FFNN(config)
model_path = hf_hub_download(repo_id=MODEL_REPO, filename="pytorch_model.bin")
# NOTE(security): torch.load unpickles the downloaded checkpoint, which can execute
# arbitrary code if the remote repository is compromised. Consider passing
# weights_only=True (PyTorch >= 1.13) or switching to a safetensors checkpoint.
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.to(DEVICE)
model.eval()  # inference mode for layers such as dropout, if the model has any

tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)


def _predict_batch(texts, threshold, max_length):
    """Return one list of predicted label names per input text (always a list of lists).

    A text with no label whose probability exceeds ``threshold`` maps to ``["None"]``.
    An empty ``texts`` yields ``[]`` without calling the tokenizer or model.
    """
    if not texts:
        return []
    encodings = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encodings["input_ids"].to(DEVICE)
    attention_mask = encodings["attention_mask"].to(DEVICE)
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
    # Independent per-label probabilities (multi-label), not a softmax over labels.
    probs = sigmoid(logits)
    binary_preds = (probs > threshold).int().cpu().numpy()
    results = []
    for row in binary_preds:
        labels = [LABELS[i] for i, v in enumerate(row) if v == 1]
        results.append(labels or ["None"])
    return results


def predict_texts(texts, threshold=0.5, max_length=MAX_LENGTH):
    """Predict emotion labels for one text or a batch of texts.

    Args:
        texts: A single string, or a list of strings.
        threshold: A label is predicted when its sigmoid probability is strictly
            greater than this value.
        max_length: Maximum tokenized length; longer inputs are truncated.

    Returns:
        For a single string (or a one-element list), a list of label names.
        For two or more texts, a list of such lists, one per text.
        For an empty list, ``[]``.
        A text with no label above the threshold maps to ``["None"]``.
    """
    if isinstance(texts, str):
        texts = [texts]
    results = _predict_batch(texts, threshold, max_length)
    # Keep the historical return shape: unwrap when there is exactly one input.
    return results[0] if len(results) == 1 else results


def wrapper(text, threshold):
    """Gradio callback: classify each non-blank line of ``text`` and format the result.

    Returns one ``"<line>: <labels>"`` row per non-blank input line, or a prompt
    message when the input has no non-blank line.
    """
    texts = [line.strip() for line in (text or "").split("\n") if line.strip()]
    if not texts:
        return "Please enter at least one non-empty sentence."
    # _predict_batch always returns one label list per text, so there is no need
    # to inspect the single-vs-batch shape that predict_texts returns.
    preds = _predict_batch(texts, threshold, MAX_LENGTH)
    return "\n".join(f"{t}: {e}" for t, e in zip(texts, preds))


with gr.Blocks() as demo:
    gr.Markdown("# Emo-detector")
    gr.Markdown("Enter a single text or multiple texts separated by line breaks.")
    input_text = gr.Textbox(
        label="Input Text",
        placeholder="Type a sentence or multiple sentences (one per line)...",
        lines=5,
    )
    threshold_slider = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=0.5,
        step=0.01,
        label="Prediction Threshold",
    )
    output = gr.Textbox(label="Predicted Emotions")
    submit_btn = gr.Button("Predict")
    submit_btn.click(wrapper, inputs=[input_text, threshold_slider], outputs=output)

demo.launch()