| from transformers import pipeline, Pipeline, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification | |
| from transformers.pipelines import PIPELINE_REGISTRY | |
| import torch | |
class SpanClassificationPipeline(Pipeline):
    """Sequence-classification pipeline that returns the argmax class index.

    Each input is tokenized, run through the wrapped model, and reduced to the
    index of its highest-scoring logit, returned as a plain ``int``.
    """

    # Call-time keyword arguments routed to the tokenizer in ``preprocess``
    # rather than to the model call in ``_forward``.
    _TOKENIZER_PARAMS = ("truncation", "max_length", "padding")

    def __init__(self, model, tokenizer, device="cpu", **kwargs):
        """Build the pipeline and put the model on ``self.device`` in eval mode.

        Args:
            model: Model invoked as ``model(**encoded_inputs, **forward_params)``;
                its output must expose a ``logits`` attribute.
            tokenizer: Tokenizer used to encode raw inputs into PyTorch tensors.
            device: Device forwarded to ``Pipeline.__init__`` (default ``"cpu"``).
            **kwargs: Extra keyword arguments passed on to ``Pipeline.__init__``.
        """
        super().__init__(model=model, tokenizer=tokenizer, device=device, **kwargs)
        # A model loaded with ``device_map`` carries ``hf_device_map`` and is
        # already placed across devices; moving it as a whole would undo that.
        if getattr(self.model, "hf_device_map", None) is None:
            self.model.to(self.device)
        self.model.eval()

    def _sanitize_parameters(self, **kwargs):
        """Split keyword arguments into (preprocess, forward, postprocess) dicts.

        Tokenizer options listed in ``_TOKENIZER_PARAMS`` go to ``preprocess``.
        Every other keyword goes to ``_forward``, which passes it to the model.
        Postprocessing takes no options.
        """
        # ``kwargs`` is a fresh dict built for this call, so popping from it is safe.
        preprocess_params = {
            key: kwargs.pop(key) for key in self._TOKENIZER_PARAMS if key in kwargs
        }
        return preprocess_params, kwargs, {}

    def preprocess(self, inputs, **tokenizer_kwargs):
        """Tokenize ``inputs`` into PyTorch tensors on the pipeline device.

        ``tokenizer_kwargs`` (e.g. ``truncation=True, max_length=512``) are
        passed straight to the tokenizer. Without ``truncation=True``, inputs
        are encoded at full length.
        """
        encoded = self.tokenizer(inputs, return_tensors="pt", **tokenizer_kwargs)
        return encoded.to(self.device)

    def _forward(self, model_inputs, **forward_params):
        """Run the model on the encoded inputs without tracking gradients.

        ``forward_params`` are the non-tokenizer keywords selected by
        ``_sanitize_parameters``; they are passed to the model call.
        """
        with torch.no_grad():
            return self.model(**model_inputs, **forward_params)

    def postprocess(self, model_outputs):
        """Return the index of the highest-scoring class as a plain ``int``.

        Presumably ``model_outputs.logits`` holds one example's scores
        (e.g. shape ``(1, num_labels)``). ``.item()`` raises if the argmax
        yields more than one element.
        """
        best_index = torch.argmax(model_outputs.logits, dim=-1)
        return int(best_index.item())
# Register SpanClassificationPipeline under the "spancnn-classification" task
# name, with AutoModelForSequenceClassification as its PyTorch model class.
PIPELINE_REGISTRY.register_pipeline(
    "spancnn-classification",
    pipeline_class=SpanClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,
)