from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from torch import cuda
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from dotenv import load_dotenv
import os
import time
import uvicorn
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Required for access to a gated model
load_dotenv()
hf_token = os.getenv("HF_TOKEN", None)
if hf_token is not None:
    login(token=hf_token)

# Configurable model identifier
model_name = os.getenv("HF_MODEL", "swiss-ai/Apertus-8B-Instruct-2509")

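# Example .env file (illustrative values; HF_MODEL is optional):
#   HF_TOKEN=hf_xxx
#   HF_MODEL=swiss-ai/Apertus-8B-Instruct-2509
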
# Model handles shared across requests
model = None
tokenizer = None


class TextInput(BaseModel):
    text: str
    min_length: int = 3
    # Apertus supports a context length of up to 65,536 tokens by default.
    # NB: the truncation in predict() below is by characters, not tokens.
    max_length: int = 65536


class ModelResponse(BaseModel):
    text: str
    confidence: float
    processing_time: float

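# Illustrative response shape (values made up):
#   {"text": "...", "confidence": 0.0, "processing_time": 1.23}

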
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the transformer model on startup and release it on shutdown"""
    global model, tokenizer
    try:
        logger.info(f"Loading model: {model_name}")
        # Report which device is available; actual placement is handled
        # by device_map="auto" below
        device = "cuda" if cuda.is_available() else "cpu"
        # Load the tokenizer and the model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",  # Automatically splits the model across CPU/GPU
            low_cpu_mem_usage=True,  # Avoids unnecessary CPU memory duplication
            offload_folder="offload",  # Temporary offload to disk
        )
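        # If GPU memory is tight, half precision usually helps: passing
        # torch_dtype=torch.bfloat16 to from_pretrained() above is one option
        # (an assumption, not part of the original setup; requires
        # `import torch` and a bfloat16-capable device).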
| logger.info(f"Model loaded successfully! ({device})") | |
| except Exception as e: | |
| logger.error(f"Failed to load model: {e}") | |
| raise e | |
| # Release resources when the app is stopped | |
| yield | |
| del model | |
| del tokenizer | |
| cuda.empty_cache() | |
# Set up the app
app = FastAPI(
    title="Apertus API",
    description="REST API for serving Apertus models via Hugging Face transformers",
    version="0.1.0",
    docs_url="/",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

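# Wide-open CORS is convenient for demos; restrict allow_origins (and the
# allowed methods/headers) before exposing this publicly.

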
@app.get("/predict")  # Route path assumed; the original listing has no decorator
async def predict(q: str):
    """Generate a model response for input text"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        start_time = time.time()
        input_data = TextInput(text=q)
        # Truncate the text if it is too long (character-based)
        text = input_data.text[:input_data.max_length]
        if len(input_data.text) > input_data.max_length:
            logger.warning("Warning: text truncated")
        if len(text) < input_data.min_length:
            logger.warning("Warning: text too short, aborting")
            raise HTTPException(status_code=400, detail="Text too short")
        # Prepare the model input
        messages_think = [
            {"role": "user", "content": text}
        ]
        text = tokenizer.apply_chat_template(
            messages_think,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer(
            [text],
            return_tensors="pt",
            add_special_tokens=False,
        ).to(model.device)
        # Generate the output; sampling parameters belong here, not in
        # apply_chat_template(), where they would be silently ignored
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
        )
        # Strip the prompt tokens and decode only the completion
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
        result = tokenizer.decode(output_ids, skip_special_tokens=True)
        processing_time = time.time() - start_time
        return ModelResponse(
            text=result,
            confidence=0.0,  # Placeholder: generate() returns no usable score here
            processing_time=processing_time,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Evaluation error: {e}")
        raise HTTPException(status_code=500, detail="Evaluation failed")

@app.get("/health")  # Route path assumed; the original listing has no decorator
async def health_check():
    """Health check and basic configuration"""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": cuda.is_available(),
    }

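# Minimal smoke-test sketch (an assumption, not part of the original setup;
# the model must fit on the test machine, and the TestClient context manager
# runs the lifespan handler, i.e. loads the model):
#   from fastapi.testclient import TestClient
#   with TestClient(app) as client:
#       print(client.get("/health").json())

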
if __name__ == "__main__":
    uvicorn.run("app:app", reload=True)
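
# Example usage (assumes the file is saved as app.py and uvicorn's default
# port 8000):
#   python app.py
#   curl "http://127.0.0.1:8000/health"
#   curl "http://127.0.0.1:8000/predict?q=Hello"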