from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
from torch import cuda
from transformers import AutoModelForCausalLM, AutoTokenizer
from hashlib import sha256
from huggingface_hub import login
from dotenv import load_dotenv
import os
import uvicorn
import time
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Required for access to a gated model
load_dotenv()
hf_token = os.getenv("HF_TOKEN", None)
if hf_token is not None:
    login(token=hf_token)

# Configurable model identifier
model_name = os.getenv("HF_MODEL", "swiss-ai/Apertus-8B-Instruct-2509")

# Maximum number of tokens to generate per request
MAX_NEW_TOKENS = 4096

# Load the base system prompt from a text file
with open('system_prompt.md', 'r') as file:
    system_prompt = file.read()

# Model and tokenizer are loaded once at startup and shared across requests
model = None
tokenizer = None


class TextInput(BaseModel):
    text: str = ""
    min_length: int = 3
    # Apertus supports a context length of up to 65,536 tokens by default.
    max_length: int = 65536


class ModelResponse(BaseModel):
    text: str
    confidence: float
    processing_time: float


class ChatMessage(BaseModel):
    role: str = "user"
    content: str = ""


class Completion(BaseModel):
    model: str = "apertus"
    messages: List[ChatMessage]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.1
    top_p: Optional[float] = 0.9


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the transformer model on startup and release it on shutdown."""
    global model, tokenizer
    try:
        logger.info(f"Loading model: {model_name}")
        # Report which device is available (placement itself is handled by device_map)
        device = "cuda" if cuda.is_available() else "cpu"
        # Load the tokenizer and the model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",         # Automatically splits the model across CPU/GPU
            low_cpu_mem_usage=True,    # Avoids unnecessary CPU memory duplication
            offload_folder="offload",  # Temporary offload to disk
        )
        logger.info(f"Model loaded successfully! ({device})")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    yield
    # Release resources when the app is stopped
    del model
    del tokenizer
    cuda.empty_cache()


# Set up the app
app = FastAPI(
    title="Apertus API",
    description="REST API for serving Apertus models via Hugging Face transformers",
    version="0.1.0",
    docs_url="/",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def fit_to_length(text, min_length=3, max_length=100):
    """Truncate text that is too long; return None if it is too short."""
    if len(text) > max_length:
        logger.warning("Text truncated to %d characters", max_length)
        text = text[:max_length]
    if len(text) < min_length:
        logger.warning("Text shorter than %d characters, aborting", min_length)
        return None
    return text


def get_completion_text(messages_think: List[ChatMessage]):
    """Concatenate all message contents into a single string."""
    return " ".join(cm.content for cm in messages_think)


def get_message_id(txt: str):
    """Derive a stable identifier from the prompt text."""
    return sha256(txt.encode()).hexdigest()


def get_model_response(messages_think: List[ChatMessage]):
    """Run the chat messages through the model and return the generated text."""
    # Normalise messages to plain dicts, as expected by the chat template
    messages = [
        m if isinstance(m, dict) else m.model_dump()
        for m in messages_think
    ]
    # Prepend the system prompt if the conversation does not already carry one
    if not any(m["role"] == "system" for m in messages):
        messages.insert(0, {"role": "system", "content": system_prompt})
    # Prepare the model input
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer(
        [text],
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)
    # Generate the output (sampling parameters are fixed server-side)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
    )
    # Keep only the newly generated tokens and decode them
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
    return tokenizer.decode(output_ids, skip_special_tokens=True)


@app.post("/v1/models/apertus")
async def completion(data: Completion):
    """Generate an OpenAI-style chat completion."""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        mt = data.messages
        text = get_completion_text(mt)
        result = get_model_response(mt)
        # OpenAI-style response object
        return {
            "id": get_message_id(text),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": data.model,
            "choices": [{
                "message": ChatMessage(role="assistant", content=result)
            }],
            # Character counts are used here as a rough stand-in for token counts
            "usage": {
                "prompt_tokens": len(text),
                "completion_tokens": len(result),
                "total_tokens": len(text) + len(result),
            },
        }
    except Exception as e:
        logger.warning(e)
        raise HTTPException(status_code=400, detail="Could not process") from e


@app.get("/predict", response_model=ModelResponse)
async def predict(q: str):
    """Generate a model response for input text."""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        start_time = time.time()
        input_data = TextInput(text=q)
        text = fit_to_length(input_data.text, input_data.min_length, input_data.max_length)
        messages_think = [
            {"role": "user", "content": text}
        ]
        result = get_model_response(messages_think)
        # Measure the total processing time
        processing_time = time.time() - start_time
        return ModelResponse(
            text=result,
            confidence=0,  # Causal generation does not produce a confidence score
            processing_time=processing_time,
        )
    except Exception as e:
        logger.warning(e)
        raise HTTPException(status_code=500, detail="Evaluation failed")


@app.get("/health")
async def health_check():
    """Health check and basic configuration."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": cuda.is_available(),
    }


if __name__ == '__main__':
    uvicorn.run('app:app', reload=True)
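
# Example requests (a sketch; assumes the file is saved as app.py and the server
# is running locally on uvicorn's default port 8000):
#
#   curl -X POST http://localhost:8000/v1/models/apertus \
#     -H "Content-Type: application/json" \
#     -d '{"model": "apertus", "messages": [{"role": "user", "content": "Hello"}]}'
#
#   curl "http://localhost:8000/predict?q=Hello"
#
#   curl http://localhost:8000/health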