| from typing import List | |
| from transformers import PreTrainedModel | |
| from transformers import PretrainedConfig | |
| from huggingface_hub import hf_hub_download | |
| from .QBModelConfig import QBModelConfig | |
| from .qbmodel import QuizBowlModel | |
# Hugging Face Hub repository that the auxiliary files fetched by this module
# are downloaded from (passed as ``repo_id`` to ``hf_hub_download``).
REPO_ID = "Backedman/TriviaAnsweringMachineREAL"
# Repo-relative path of a pickle file; presumably a TF-IDF model for the
# "Mythology" category -- TODO confirm. Not referenced in the visible part of
# this file.
FILENAME = "models/Mythology_tfidf.pkl"
class QBModelWrapper(PreTrainedModel):
    """Adapter exposing ``QuizBowlModel`` through the ``PreTrainedModel`` API.

    The wrapper owns a ``QuizBowlModel`` instance and delegates prediction to
    it: ``predict`` is the wrapped model's ``predict`` method, and ``forward``
    returns the first element of what ``predict`` produces.

    Executing the class body (i.e. importing this module) downloads three
    files from ``REPO_ID`` into the current working directory.
    """

    # Configuration class associated with this model.
    config_class = QBModelConfig
    # NOTE(review): this binds the config *class*, not an instance; it looks
    # redundant with ``config_class`` and is presumably shadowed per instance
    # once the base constructor stores the config it is given -- confirm
    # before removing. Kept so the class namespace stays unchanged.
    config = QBModelConfig

    # Fetch the auxiliary sources and the categorizer artifact into the current
    # directory. These calls run at class-definition time, so importing this
    # module needs access to the Hub (or a populated cache).
    # NOTE(review): presumably QuizBowlModel needs these files -- confirm.
    hf_hub_download(repo_id=REPO_ID, filename='tfidf_model.py', local_dir='.')
    hf_hub_download(repo_id=REPO_ID, filename='question_categorizer.py', local_dir='.')
    hf_hub_download(repo_id=REPO_ID, filename='models/categorizer', local_dir='.')

    def __init__(self, config):
        """Initialise the base model and create the wrapped ``QuizBowlModel``.

        Args:
            config: Configuration object forwarded unchanged to the
                ``PreTrainedModel`` constructor.
        """
        super().__init__(config)
        self.model = QuizBowlModel()
        # Re-export the wrapped model's ``predict`` so callers can invoke it
        # directly on the wrapper.
        self.predict = self.model.predict

    def forward(self, question):
        """Run the wrapped model on ``question`` and return its first result.

        Args:
            question: Input passed straight through to
                ``QuizBowlModel.predict``.

        Returns:
            Element ``0`` of the value returned by ``predict``.
        """
        output = self.model.predict(question)
        return output[0]