import datetime
import os
import re
import time
import uuid
import zipfile
from io import BytesIO

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import TreebankWordTokenizer
from PIL import Image
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS


# Training-time artifacts, loaded once at import time in this order:
#   lda         - fitted topic model (used for .transform() and .components_)
#   vectorizer  - fitted text vectorizer (used for .transform() and
#                 .get_feature_names_out())
#   auto_labels - topic label lookup, queried with .get(topic_index, default);
#                 presumably a dict keyed by topic index - confirm.
# NOTE(review): joblib.load unpickles its input; only load these files from a
# trusted source.
lda, vectorizer, auto_labels = (
    joblib.load(artifact_path)
    for artifact_path in (
        "lda_model.joblib",
        "vectorizer.joblib",
        "topic_labels.joblib",
    )
)


# One-sentence description per topic label, shown next to a prediction.
# The predicted label is looked up here with a "No summary available."
# fallback, so a label missing from this table degrades gracefully.
# NOTE(review): keys presumably mirror the label strings stored in
# topic_labels.joblib - keep the two in sync.
topic_summaries = {
    "Politics & Gun Rights":
        "Discussions about government policies, laws, gun control, and rights.",
    "Computing & Hardware":
        "Technical issues and terms related to computer hardware and drivers.",
    "Programming & Software":
        "Programming terms, file handling, software output.",
    "Sports & Games":
        "Topics related to teams, players, seasons, and matches.",
    "Health & Medicine":
        "Diseases, treatment, healthcare, and medical facilities.",
    "Religion & Philosophy":
        "Talks involving faith, belief systems, philosophical views.",
    "Space & NASA":
        "Space exploration, NASA missions, satellites, and astronomy.",
    "Cryptography & Security":
        "Discussions on encryption, digital security, and data protection.",
    "Internet & Networking":
        "Terms around internet use, FTP, web versions, and networks.",
    "Middle East Politics & Conflicts":
        "Topics involving Israel, Armenia, conflict regions.",
}


# Shared NLP helpers, created once and reused by every preprocess() call.
# NOTE(review): WordNetLemmatizer likely needs the NLTK "wordnet" corpus to be
# installed in the environment; nothing in this file downloads it - confirm.
tokenizer, lemmatizer = TreebankWordTokenizer(), WordNetLemmatizer()


def preprocess(text):
    """Normalize raw text into the space-separated token string fed to the vectorizer.

    Steps: lower-case, collapse every run of non-word characters into a single
    space, tokenize with the module-level ``tokenizer``, discard stop words,
    tokens of length <= 2 and non-alphabetic tokens, then lemmatize the rest.

    NOTE(review): the vectorizer was presumably fit on text produced by this
    exact pipeline; changing it would silently skew predictions.
    """
    normalized = re.sub(r'\W+', ' ', text.lower())
    kept_lemmas = []
    for token in tokenizer.tokenize(normalized):
        # Guard-clause form of: not a stop word, longer than 2, letters only.
        if token in ENGLISH_STOP_WORDS or len(token) <= 2 or not token.isalpha():
            continue
        kept_lemmas.append(lemmatizer.lemmatize(token))
    return ' '.join(kept_lemmas)


def get_topic_keywords(model, vectorizer, topic_idx, top_n=10):
    """Return the ``top_n`` highest-weighted vocabulary terms of one topic.

    Args:
        model: Fitted topic model exposing ``components_``, indexed by topic
            to obtain a row of per-term weights.
        vectorizer: Fitted vectorizer whose ``get_feature_names_out()`` maps a
            column index to its term.
        topic_idx: Row of ``model.components_`` to inspect.
        top_n: Number of terms to return (all terms if it exceeds the
            vocabulary size).

    Returns:
        List of terms ordered from highest to lowest weight.
    """
    vocabulary = vectorizer.get_feature_names_out()
    weights = model.components_[topic_idx]
    # argsort is ascending: reverse it so the heaviest terms come first.
    ranked_columns = weights.argsort()[::-1][:top_n]
    return [vocabulary[column] for column in ranked_columns]


def plot_topic_distribution(distribution, labels):
    """Render per-topic probabilities as a bar chart and return it as a PIL image.

    Args:
        distribution: Sequence of topic probabilities, one per topic.
        labels: Tick label for each bar; must have the same length as
            ``distribution``.

    Returns:
        A ``PIL.Image.Image`` holding the chart rendered as PNG.
    """
    # Use a standalone Figure rather than pyplot: pyplot draws into a
    # process-global "current figure", so two requests handled concurrently
    # could draw onto the same figure. A Figure that pyplot does not track is
    # also simply garbage-collected, so nothing leaks if drawing raises.
    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()
    ax.bar(range(len(distribution)), distribution, tick_label=labels)
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(45)
        tick_label.set_horizontalalignment("right")
    ax.set_ylabel("Probability")
    ax.set_title("Topic Distribution")
    fig.tight_layout()

    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    image = Image.open(buf)
    # Image.open is lazy; decode now so the result is self-contained.
    image.load()
    return image


def save_prediction_file(text):
    """Write a prediction report to a uniquely named UTF-8 text file in the CWD.

    The name has the form ``lda_prediction_<YYYYmmdd_HHMMSS>_<random>.txt``,
    which keeps the ``lda_prediction_`` prefix and ``.txt`` extension that
    ``cleanup_old_predictions`` matches on.

    Args:
        text: Report contents to write.

    Returns:
        The (relative) filename that was written.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # A per-second timestamp alone collides when two predictions finish within
    # the same second, and the later write would overwrite the earlier user's
    # report. The random suffix makes every name unique.
    filename = f"lda_prediction_{timestamp}_{uuid.uuid4().hex[:12]}.txt"
    with open(filename, "w", encoding="utf-8") as f:
        f.write(text)
    return filename


def cleanup_old_predictions(directory=".", extension=".txt", max_age_minutes=10, prefix="lda_prediction_"):
    """Delete prediction report files older than ``max_age_minutes``.

    Only regular files in ``directory`` whose names start with ``prefix`` and
    end with ``extension`` are considered. This is best-effort: failures are
    printed and never raised for an individual file.

    Args:
        directory: Directory to scan.
        extension: Required filename suffix.
        max_age_minutes: Files whose modification time is more than this many
            minutes ago are removed.
        prefix: Required filename prefix.
    """
    cutoff = time.time() - max_age_minutes * 60
    for fname in os.listdir(directory):
        if not (fname.startswith(prefix) and fname.endswith(extension)):
            continue
        full_path = os.path.join(directory, fname)
        # The stat calls live inside the try: another request running this
        # same cleanup may delete the file between listdir() and here, and
        # that must not abort the caller's prediction.
        try:
            if os.path.isfile(full_path) and os.path.getmtime(full_path) < cutoff:
                os.remove(full_path)
        except FileNotFoundError:
            continue
        except OSError as e:
            print(f"Failed to delete {fname}: {e}")


def download_log():
    """Bundle the prediction and feedback CSV logs into one zip archive.

    Only log files that currently exist are added, so the download works
    before any prediction or feedback has been recorded.

    Returns:
        The zip filename, or ``None`` when neither log file exists yet.
    """
    log_files = [
        path
        for path in ("lda_predictions_log.csv", "lda_feedback_log.csv")
        if os.path.isfile(path)
    ]
    if not log_files:
        # Nothing to bundle; zipfile.write() on a missing path would raise.
        return None
    zip_filename = "lda_predictions_log.zip"
    with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
        for path in log_files:
            zipf.write(path)
    return zip_filename


def save_feedback(text, feedback):
    """Append one user-feedback row to ``lda_feedback_log.csv``.

    Args:
        text: Input text the feedback refers to; only its first 300
            characters (newlines flattened to spaces) are stored. ``None`` is
            treated as empty.
        feedback: The selected feedback option. Surrounding whitespace is
            stripped before logging.

    Returns:
        A status message for display in the UI. No row is written when no
        option was selected.
    """
    choice = (feedback or "").strip()
    if not choice:
        # Refuse instead of appending a row with an empty feedback column.
        return "Please select a feedback option before submitting."

    timestamp = datetime.datetime.now().isoformat()
    log_entry = pd.DataFrame([{
        "timestamp": timestamp,
        "feedback": choice,
        "text_excerpt": (text or "")[:300].replace('\n', ' ') + "...",
    }])
    feedback_log = "lda_feedback_log.csv"
    # Append mode; write the header only when the file is first created.
    log_entry.to_csv(feedback_log, mode='a', header=not os.path.exists(feedback_log), index=False)
    return " Feedback recorded. Thank you!"


def _read_uploaded_text(file_input):
    """Return the text of an uploaded file, decoding bytes as UTF-8 with replacement.

    Handles a filesystem path (``str``/``os.PathLike``), a file-like object
    exposing a ``.name`` path, raw ``bytes``, or any object with ``.read()``.
    NOTE(review): which shape arrives depends on the installed Gradio version
    and the ``gr.File`` ``type`` setting - confirm against the deployment.
    """
    if isinstance(file_input, bytes):
        raw = file_input
    else:
        path = file_input if isinstance(file_input, (str, os.PathLike)) else getattr(file_input, "name", None)
        if isinstance(path, (str, os.PathLike)):
            # Re-open by path: works whether or not the original handle is
            # still open or positioned at the start.
            with open(path, "rb") as f:
                raw = f.read()
        else:
            raw = file_input.read()
    if isinstance(raw, str):
        return raw
    # errors="replace" so a non-UTF-8 .txt still yields a prediction
    # instead of a UnicodeDecodeError.
    return raw.decode("utf-8", errors="replace")


def _log_prediction(label, dominant_topic, top_words, text):
    """Append one prediction row to ``lda_predictions_log.csv``."""
    log_entry = pd.DataFrame([{
        "timestamp": datetime.datetime.now().isoformat(),
        "predicted_topic": label,
        "dominant_topic_index": int(dominant_topic),
        "top_words": ", ".join(top_words),
        "text_excerpt": text[:300].replace('\n', ' ') + "...",
    }])
    log_path = "lda_predictions_log.csv"
    # Append mode; write the header only when the file is first created.
    log_entry.to_csv(log_path, mode='a', header=not os.path.exists(log_path), index=False)


def _format_result(label, summary, top_words, topic_distribution, topic_labels):
    """Build the textual prediction report shown in the UI and saved to file."""
    parts = [
        f" **Predicted Topic:** {label}\n\n",
        f" **Summary:** {summary}\n\n",
        f" **Top Words:** {', '.join(top_words)}\n\n",
        " **Topic Distribution:**\n",
    ]
    parts.extend(
        f"{tlabel}: {prob:.3f}\n"
        for tlabel, prob in zip(topic_labels, topic_distribution)
    )
    return "".join(parts)


def predict_topic(text_input, file_input):
    """Predict the dominant LDA topic for pasted text or an uploaded file.

    An uploaded file takes precedence over the textbox. Besides returning the
    report, this prunes old report files, appends a row to
    ``lda_predictions_log.csv`` and writes the report to a downloadable
    ``lda_prediction_*.txt`` file.

    Args:
        text_input: Pasted text (may be empty or ``None``).
        file_input: Uploaded file, or ``None``.

    Returns:
        ``(report_text, chart_image, report_path)``; the last two are ``None``
        when there is nothing usable to predict on.
    """
    cleanup_old_predictions()

    if file_input is not None:
        text = _read_uploaded_text(file_input)
    elif text_input and text_input.strip():
        text = text_input
    else:
        return "Please provide input", None, None

    cleaned = preprocess(text)
    bow = vectorizer.transform([cleaned])
    # When no token survives preprocessing and vocabulary lookup, the
    # document-term vector is all zeros and the resulting topic distribution
    # says nothing about the text, so don't present it as a prediction.
    if bow.sum() == 0:
        return (
            "No recognizable words found in the input. "
            "Try providing a longer or more specific text.",
            None,
            None,
        )

    topic_distribution = lda.transform(bow)[0]
    dominant_topic = int(topic_distribution.argmax())
    topic_labels = [auto_labels.get(i, f"Topic {i+1}") for i in range(len(topic_distribution))]
    label = topic_labels[dominant_topic]
    top_words = get_topic_keywords(lda, vectorizer, dominant_topic)
    summary = topic_summaries.get(label, "No summary available.")

    # Below this probability the dominant topic is flagged as unreliable.
    confidence_threshold = 0.4
    if topic_distribution[dominant_topic] < confidence_threshold:
        label += " ( Low confidence)"
        summary = " The model is uncertain. Try providing more context or a longer input."

    _log_prediction(label, dominant_topic, top_words, text)
    chart = plot_topic_distribution(topic_distribution, topic_labels)
    result = _format_result(label, summary, top_words, topic_distribution, topic_labels)
    prediction_file = save_prediction_file(result)
    return result, chart, prediction_file


# Gradio UI: inputs and feedback controls on the left, results on the right.
with gr.Blocks() as demo:
    gr.Markdown("## Topic Modeling with LDA")
    gr.Markdown("Upload a `.txt` file or paste in text. See predicted topic, keywords, and a chart.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(lines=10, label=" Paste Text")
            file_input = gr.File(label=" Or Upload a .txt File", file_types=[".txt"])
            predict_btn = gr.Button(" Predict Topic")
            download_btn = gr.Button("⬇ Download All Logs")

            # Choice values are written verbatim to the feedback CSV, so they
            # carry no stray whitespace.
            feedback_input = gr.Radio(
                choices=["Accurate", "Inaccurate", "Unclear"],
                label=" Was this prediction useful?",
                interactive=True,
            )
            feedback_btn = gr.Button("Submit Feedback")
            # Shown read-only so the user sees save_feedback's status message.
            feedback_output = gr.Textbox(label="Feedback Status", interactive=False)

        with gr.Column():
            output_text = gr.Textbox(label=" Prediction Result")
            output_chart = gr.Image(type="pil", label=" Topic Distribution")
            download_prediction = gr.File(label="⬇ Download This Prediction")
            # Declared in the layout so the zipped logs get a label and a
            # fixed place on the page.
            download_logs = gr.File(label="⬇ All Logs (zip)")

    predict_btn.click(
        fn=predict_topic,
        inputs=[text_input, file_input],
        outputs=[output_text, output_chart, download_prediction],
    )

    download_btn.click(fn=download_log, outputs=[download_logs])

    feedback_btn.click(
        fn=save_feedback,
        inputs=[text_input, feedback_input],
        outputs=[feedback_output],
    )


# Launch only when run as a script, so importing this module (e.g. by tooling
# that looks up `demo`) does not start a second server.
if __name__ == "__main__":
    demo.launch()