Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import pipeline | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| #from PIL import Image | |
| #pipe = pipeline(model="RuudVelo/dutch_news_classifier_bert_finetuned") | |
| #text = st.text_area('Please type/copy/paste the Dutch article') | |
| #labels = ['Binnenland' 'Buitenland' 'Cultuur & Media' 'Economie' 'Koningshuis' | |
| # 'Opmerkelijk' 'Politiek' 'Regionaal nieuws' 'Tech'] | |
| #if text: | |
| # out = pipe(text) | |
| # st.json(out) | |
| # load tokenizer and model, create trainer | |
| #model_name = "RuudVelo/dutch_news_classifier_bert_finetuned" | |
| #tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| #model = AutoModelForSequenceClassification.from_pretrained(model_name) | |
| #trainer = Trainer(model=model) | |
| #print(filename, type(filename)) | |
| #print(filename.name) | |
from transformers import BertForSequenceClassification, BertTokenizer

# Single checkpoint identifier shared by the classifier and its tokenizer, so
# the two can never drift apart.
MODEL_ID = "RuudVelo/dutch_news_clf_bert_finetuned"

# Both objects are created at module level; the Submit handler further down
# reads them as the globals `model` and `tokenizer`.
model = BertForSequenceClassification.from_pretrained(MODEL_ID)
tokenizer = BertTokenizer.from_pretrained(MODEL_ID)
# Page header: title, short description and cover image.
st.title("Dutch news article classification")
st.write(
    "This app classifies a Dutch news article into one of 9 pre-defined* article categories"
)
# Relative path, so it is resolved against the app's working directory.
st.image("dataset-cover_articles.jpeg", width=150)

# Article input; the Submit handler below reads this value as `text`.
text = st.text_area(
    "Please type/copy/paste text of the Dutch article and click Submit"
)
| #if text: | |
| # encoding = tokenizer(text, return_tensors="pt") | |
| # outputs = model(**encoding) | |
| # predictions = outputs.logits.argmax(-1) | |
| # probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
| ## fig = plt.figure() | |
| # ax = fig.add_axes([0,0,1,1]) | |
| # labels_plot = ['Binnenland', 'Buitenland' ,'Cultuur & Media' ,'Economie' ,'Koningshuis', | |
| # 'Opmerkelijk' ,'Politiek', 'Regionaal nieuws', 'Tech'] | |
| # probs_plot = probabilities[0].cpu().detach().numpy() | |
| # ax.barh(labels_plot,probs_plot ) | |
| # st.pyplot(fig) | |
| #input = st.text_input('Context') | |
# Classify the article only when the user explicitly clicks Submit.
if st.button('Submit'):
    if not text.strip():
        # Running the model on an empty string would only produce a
        # meaningless prediction, so ask for input instead.
        st.warning('Please paste the text of a Dutch article before clicking Submit.')
    else:
        with st.spinner('Generating a response...'):
            # Truncate to the model's position limit: longer inputs would
            # index past the position-embedding table and raise at runtime.
            encoding = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=model.config.max_position_embeddings,
            )
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                outputs = model(**encoding)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Probabilities for the single input, expressed as percentages.
        probs_plot = probabilities[0].cpu().numpy() * 100
        # Plain Python int so it can index both the label list and the
        # NumPy array (indexing with a torch tensor yields a 1-element array
        # that cannot be formatted with '{:.1f}').
        number = int(probs_plot.argmax())

        # Order must match the label ids the model was trained with.
        labels_plot = [
            'Binnenland', 'Buitenland', 'Cultuur & Media', 'Economie', 'Koningshuis',
            'Opmerkelijk', 'Politiek', 'Regionaal nieuws', 'Tech',
        ]

        fig = plt.figure(figsize=(10, 4))
        ax = fig.add_axes([0, 0, 1, 1])
        ax.barh(labels_plot, probs_plot)
        ax.set_title("Predicted article category probability", fontsize=20)
        ax.set_xlabel("Probability")
        ax.set_ylabel("Predicted category")
        st.pyplot(fig)
        # Streamlit reruns the whole script on every interaction; close the
        # figure so pyplot does not accumulate one per rerun.
        plt.close(fig)

        st.write(
            'The predicted category is: **{}** with a probability of: **{:.1f}%**'.format(
                labels_plot[number], probs_plot[number]
            )
        )
| # output = genQuestion(option, input) | |
| # print(output) | |
| # st.write(output) | |
| #encoding = tokenizer(text, return_tensors="pt") | |
| #import numpy as np | |
# Static footer shown on every run: category list, dataset source and a link
# to the model card. Each entry is rendered as its own paragraph, in order.
FOOTER_NOTES = (
    "The pre-defined categories are Binnenland, Buitenland, Cultuur & Media, Economie , Koningshuis, Opmerkelijk, Politiek, 'Regionaal nieuws en Tech",
    "The model for this app has been trained using data from Dutch news articles published by NOS. For more information regarding the dataset can be found at https://www.kaggle.com/maxscheijen/dutch-news-articles",
    "The model performance details can be found at https://huggingface.co/RuudVelo/dutch_news_classifier_bert_finetuned",
)
for note in FOOTER_NOTES:
    st.write(note)