eurovoc / app.py
ZseeBrz's picture
Update app.py
cfbb0e6 verified
# -*- coding: utf-8 -*-
from huggingface_hub import snapshot_download
# Fetch the model repository before the rest of the module runs: EndpointHandler
# (instantiated further down with path=".") opens "./mlb.pickle" and loads the
# model weights from the same directory.
# NOTE(review): local_dir="" presumably resolves to the current working
# directory — confirm against the huggingface_hub documentation.
snapshot_download(repo_id="apapagi/eurovoc_2025", local_dir = "")
import torch
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
from transformers import AutoModel
from huggingface_hub import PyTorchModelHubMixin
from typing import Dict, List, Any
import numpy as np
import pickle
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer
import torch
import textract
import gradio as gr
# Neural network used for the forward pass.
class EurovocTagger(pl.LightningModule, PyTorchModelHubMixin):
    """Multi-label tagger: a pretrained encoder followed by a sigmoid head.

    Args:
        bert_model_name: Identifier handed to ``AutoModel.from_pretrained``.
        n_classes: Number of output units of the linear classification layer.
        lr: Stored as ``self.lr``; not read anywhere else in this class.
        eps: Stored as ``self.eps``; not read anywhere else in this class.
    """

    def __init__(self, bert_model_name, n_classes, lr=2e-5, eps=1e-8):
        super().__init__()
        # Submodule attribute names and their creation order are kept
        # unchanged: they determine the parameter names of the saved weights.
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(p=0.2)
        hidden_size = self.bert.config.hidden_size
        self.classifier1 = nn.Linear(hidden_size, n_classes)
        self.criterion = nn.BCELoss()
        self.lr, self.eps = lr, eps

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and the classification head.

        Returns:
            A ``(loss, probabilities)`` tuple. ``loss`` is the integer ``0``
            when ``labels`` is ``None``, otherwise the BCE loss between the
            sigmoid outputs and ``labels``.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        pooled = self.dropout(encoded.pooler_output)
        probabilities = torch.sigmoid(self.classifier1(pooled))
        if labels is None:
            return 0, probabilities
        return self.criterion(probabilities, labels), probabilities
# Model identifier of the pretrained encoder; also used to load the tokenizer below.
BERT_MODEL_NAME = "EuropeanParliament/EUBERT"
# Same repo id as the snapshot_download call at the top of the file; this
# constant is not referenced elsewhere in the visible code.
EUROVOC_MODEL_NAME = "apapagi/eurovoc_2025"
# Used twice: as the character length of each text chunk in
# EndpointHandler.get_prediction and as the tokenizer max_length in
# EndpointHandler._get_prediction.
MAX_LEN = 512
# Upper bound on how many characters of an input text get chunked
# (i.e. at most 50 chunks of MAX_LEN characters).
TEXT_MAX_LEN = MAX_LEN * 50
# Module-level tokenizer, loaded once at import time and used by
# EndpointHandler._get_prediction.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
# Define the classifier endpoint and load the weights.
class EndpointHandler:
    """Callable inference wrapper around ``EurovocTagger``.

    On construction it unpickles the label binarizer from ``{path}/mlb.pickle``
    and loads the model weights from ``path``. Calling the instance with a
    payload dict returns the best-scoring labels for the payload text.
    """

    # Set per instance in __init__ from the unpickled file. Quoted so that the
    # class body does not need to instantiate a throwaway binarizer.
    mlb: "MultiLabelBinarizer"

    def __init__(self, path="."):
        """Load the label binarizer and the model from ``path``.

        Args:
            path: Directory containing ``mlb.pickle`` and the model files.

        Raises:
            OSError: If ``{path}/mlb.pickle`` cannot be opened.
        """
        # NOTE(security): pickle.load can execute arbitrary code contained in
        # the file. Only point `path` at a trusted model repository.
        with open(f"{path}/mlb.pickle", "rb") as mlb_file:
            self.mlb = pickle.load(mlb_file)
        print ("pickle loaded")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print (torch.cuda.is_available())
        # map_location is deliberately not passed here: the original author
        # noted that passing it makes from_pretrained error out.
        self.model = EurovocTagger.from_pretrained(
            path,
            bert_model_name=BERT_MODEL_NAME,
            n_classes=len(self.mlb.classes_),
        )
        # _get_prediction moves the inputs to self.device, so the model has to
        # live on the same device or the forward pass fails on CUDA hosts.
        self.model.to(self.device)
        print ("model loaded")
        self.model.eval()
        self.model.freeze()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score the payload text and return the top labels.

        Args:
            data: Payload dict. Recognised keys (all removed from ``data``):
                ``inputs`` (the text; falls back to ``data`` itself when
                absent), ``topk`` (default 5; ``None`` means no limit),
                ``threshold`` (default 0.16, exclusive lower bound on the
                score) and ``debug`` (default ``False``).

        Returns:
            ``{"results": [...]}`` where each entry is
            ``{"label": ..., "score": float}``, sorted by descending score.
            With ``debug`` set, the raw prediction array (``"values"``) and
            the input text (``"input"``) are included as well.
        """
        text = data.pop("inputs", data)
        topk = data.pop("topk", 5)
        if topk is not None:
            # UI sliders may deliver the value as a float; slices need an int.
            topk = int(topk)
        threshold = data.pop("threshold", 0.16)
        debug = data.pop("debug", False)

        prediction = self.get_prediction(text)
        scored = zip(self.mlb.classes_, prediction[0].tolist())
        results = [
            {"label": label, "score": float(score)}
            for label, score in scored
            if score > threshold
        ]
        # The sort is stable, so ties keep the order of self.mlb.classes_.
        results.sort(key=lambda entry: entry["score"], reverse=True)
        results = results[:topk]

        if debug:
            return {"results": results, "values": prediction, "input": text}
        return {"results": results}

    def get_prediction(self, text):
        """Average the per-chunk predictions over ``text``.

        The text is cut into pieces of ``MAX_LEN`` *characters* (only the first
        ``TEXT_MAX_LEN`` characters are considered); each piece is scored
        separately and the scores are averaged.

        Args:
            text: Text to classify. Empty or ``None`` yields all-zero scores.

        Returns:
            A numpy array that ``__call__`` indexes with ``[0]`` to obtain one
            score per entry of ``self.mlb.classes_``.
        """
        if not text:
            # Without this guard there are no chunks, the mean of an empty
            # array is a NaN scalar, and the caller's `prediction[0]` raises.
            return np.zeros((1, len(self.mlb.classes_)), dtype=np.float32)
        limit = min(len(text), TEXT_MAX_LEN)
        chunks = [text[start:start + MAX_LEN] for start in range(0, limit, MAX_LEN)]
        predictions = [self._get_prediction(chunk) for chunk in chunks]
        return np.array(predictions).mean(axis=0)

    def _get_prediction(self, text):
        """Tokenize one chunk, run the model and return its scores as numpy.

        The chunk is padded or truncated to ``MAX_LEN`` tokens.
        """
        item = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=MAX_LEN,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt')
        item = item.to(self.device)
        _, prediction = self.model(item["input_ids"], item["attention_mask"])
        prediction = prediction.cpu().detach().numpy()
        print(text, prediction)
        return prediction
# Build the single shared handler at import time; it reads mlb.pickle and the
# model weights from the current directory (path=".").
endpoint = EndpointHandler(path=".")
# Application logic behind the UI.
def eurovoc_app(text, no_of_classifiers, threshold, file):
    """Classify either the typed text or the contents of an uploaded file.

    Args:
        text: Text typed into the UI; used only when ``file`` is ``None``.
        no_of_classifiers: Forwarded to the endpoint as ``topk``.
        threshold: Forwarded to the endpoint as ``threshold``.
        file: Uploaded file object (its ``.name`` is read), or ``None``.

    Returns:
        The endpoint's result dict, or an ``"Error processing file: ..."``
        string when handling the uploaded file raises.
    """

    def _classify(content):
        # One place that builds the payload for both input sources.
        return endpoint({
            "inputs": content,
            "topk": no_of_classifiers,
            "threshold": threshold})

    # No upload: classify the typed text. Errors here propagate to the caller.
    if file is None:
        return _classify(text)

    # Upload present: extraction and classification share one error boundary.
    try:
        extracted = textract.process(file.name).decode('utf-8')
        return _classify(extracted)
    except Exception as e:
        return f"Error processing file: {e}"
# User interface: the input components are listed in the same order as the
# positional parameters of eurovoc_app (text, no_of_classifiers, threshold, file).
eurovoc_interface = gr.Interface(
    fn=eurovoc_app,
    inputs=[
        gr.Textbox(label="Enter text or upload file"),
        gr.Slider(minimum=2, maximum=12, step=1, label="Number of Classifiers"),
        gr.Slider(minimum=0, maximum=1, label="Weight thresholds"),
        gr.File(label="Upload your file"),
    ],
    outputs=gr.Textbox(lines=10, label="Prediction Results"),
    allow_flagging=None,
)
# Start the app.
eurovoc_interface.launch(share=True)