from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ValidationError
from typing import List, Optional
from torch import cuda
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from dotenv import load_dotenv
import os
import time
import uvicorn
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Required for access to a gated model
load_dotenv()
hf_token = os.getenv("HF_TOKEN", None)
if hf_token is not None:
    login(token=hf_token)

# Configurable model identifier
model_name = os.getenv("HF_MODEL", "swiss-ai/Apertus-8B-Instruct-2509")

# Module-level references, populated on startup
model = None
tokenizer = None


class TextInput(BaseModel):
    text: str
    min_length: int = 3
    # Apertus supports a context length of up to 65,536 tokens by default.
    max_length: int = 65536


class ModelResponse(BaseModel):
    text: str
    confidence: float
    processing_time: float


class ChatMessage(BaseModel):
    role: str
    content: str


class Completion(BaseModel):
    model: str = "apertus"
    messages: List[ChatMessage]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.1
    top_p: Optional[float] = 0.9


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the transformer model on startup, release it on shutdown."""
    global model, tokenizer
    try:
        logger.info(f"Loading model: {model_name}")
        # Automatically select the device based on availability
        device = "cuda" if cuda.is_available() else "cpu"
        # Load the tokenizer and the model; device_map="auto" replaces
        # an explicit .to(device) call
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",         # Automatically splits model across CPU/GPU
            low_cpu_mem_usage=True,    # Avoids unnecessary CPU memory duplication
            offload_folder="offload",  # Temporary offload to disk
        )
        logger.info(f"Model loaded successfully! ({device})")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

    yield

    # Release resources when the app is stopped
    del model
    del tokenizer
    if cuda.is_available():
        cuda.empty_cache()


# Set up our app
app = FastAPI(
    title="Apertus API",
    description="REST API for serving Apertus models via Hugging Face transformers",
    version="0.1.0",
    docs_url="/",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def fit_to_length(text, min_length=3, max_length=100):
    """Truncate text that is too long; reject text that is too short."""
    text = text[:max_length]
    if len(text) == max_length:
        logger.warning("Warning: text truncated")
    if len(text) < min_length:
        logger.warning("Warning: text too short, aborting")
        return None
    return text


def get_model_response(messages, max_new_tokens=512, temperature=0.8, top_p=0.9):
    """Generate a completion for a list of chat messages."""
    # Prepare the model input
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer(
        [text],
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    # Generate the output; sampling parameters belong here,
    # not in apply_chat_template
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )

    # Strip the prompt tokens and decode only the newly generated ones
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
    return tokenizer.decode(output_ids, skip_special_tokens=True)


@app.post("/v1/models/apertus")
async def completion(data: Completion):
    """Generate an OpenAI-style completion"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        messages = [{"role": m.role, "content": m.content} for m in data.messages]
        result = get_model_response(
            messages, data.max_tokens, data.temperature, data.top_p
        )
        prompt_text = " ".join(m.content for m in data.messages)
        return {
            "choices": [
                {
                    "text": result,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "length",
                }
            ],
            # Rough character-based counts; use the tokenizer for exact numbers
            "usage": {
                "prompt_tokens": len(prompt_text),
                "completion_tokens": len(result),
                "total_tokens": len(prompt_text) + len(result),
            },
        }
    except ValidationError as e:
        raise HTTPException(status_code=400, detail="Invalid input data") from e


@app.get("/predict", response_model=ModelResponse)
async def predict(q: str):
    """Generate a model response for input text"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        start_time = time.time()
        input_data = TextInput(text=q)
        text = fit_to_length(
            input_data.text, input_data.min_length, input_data.max_length
        )
        if text is None:
            raise HTTPException(status_code=400, detail="Input text too short")
        messages = [
            {"role": "user", "content": text}
        ]
        result = get_model_response(messages)
        processing_time = time.time() - start_time
        return ModelResponse(
            text=result,
            confidence=0,  # Placeholder: generate() does not return a score
            processing_time=processing_time,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Evaluation error: {e}")
        raise HTTPException(status_code=500, detail="Evaluation failed")


@app.get("/health")
async def health_check():
    """Health check and basic configuration"""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": cuda.is_available(),
    }


if __name__ == '__main__':
    uvicorn.run('app:app', reload=True)
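
# --- Example usage ---
# A minimal client sketch against the endpoints defined above. This assumes
# the server is running locally on uvicorn's default port 8000; adjust the
# base URL to match your deployment.
#
#   import requests
#
#   # Simple single-turn prediction
#   r = requests.get("http://127.0.0.1:8000/predict", params={"q": "Hello!"})
#   print(r.json()["text"])
#
#   # OpenAI-style chat completion
#   r = requests.post(
#       "http://127.0.0.1:8000/v1/models/apertus",
#       json={"messages": [{"role": "user", "content": "Hello!"}]},
#   )
#   print(r.json()["choices"][0]["text"])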