from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ValidationError
from typing import List, Optional
from torch import cuda
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
)
from hashlib import sha256
from huggingface_hub import login
from dotenv import load_dotenv
from datetime import datetime
import os
import uvicorn
import time
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Required for access to a gated model
load_dotenv()
hf_token = os.getenv("HF_TOKEN", None)
if hf_token is not None:
    login(token=hf_token)

# Configurable model identifier
model_name = os.getenv("HF_MODEL", "swiss-ai/Apertus-8B-Instruct-2509")
model_quantization = int(os.getenv("QUANTIZE", 0))  # 8, 4, or 0 (no quantization)

# Maximum number of tokens to generate per request
MAX_NEW_TOKENS = 4096

# Load the base system prompt from a text file
system_prompt = None
if int(os.getenv("USE_SYSTEM_PROMPT", 1)):
    with open('system_prompt.md', 'r') as file:
        system_prompt = file.read()

# Globals holding the loaded model and tokenizer between requests
model = None
tokenizer = None


class TextInput(BaseModel):
    text: str = ""
    min_length: int = 3
    # Apertus by default supports a context length up to 65,536 tokens.
    max_length: int = 65536


class ModelResponse(BaseModel):
    text: str
    confidence: float
    processing_time: float


class ChatMessage(BaseModel):
    role: str = "user"
    content: str = ""


class Completion(BaseModel):
    model: str = "apertus"
    messages: List[ChatMessage]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.1
    top_p: Optional[float] = 0.9
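
# Example request body for the chat completion endpoint defined further below.
# This mirrors the Completion schema above; the route path used later is an
# assumption of this sketch, not something fixed by the schema itself:
#
# {
#   "model": "apertus",
#   "messages": [
#     {"role": "user", "content": "Hello, who are you?"}
#   ],
#   "max_tokens": 512,
#   "temperature": 0.1,
#   "top_p": 0.9
# }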

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the transformer model on startup and release it on shutdown."""
    global model, tokenizer
    try:
        logger.info(f"Loading model: {model_name}")
        # Automatically select device based on availability
        device = "cuda" if cuda.is_available() else "cpu"
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Optional quantization setting
        bnb_config = None
        if model_quantization == 8:
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        elif model_quantization == 4:
            bnb_config = BitsAndBytesConfig(load_in_4bit=True)
        if bnb_config is not None:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",               # Automatically splits the model across CPU/GPU
                offload_folder="offload",        # Temporary offload to disk
                low_cpu_mem_usage=True,          # Avoids unnecessary CPU memory duplication
                quantization_config=bnb_config,  # Reduces memory use and overhead
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",         # Automatically splits the model across CPU/GPU
                offload_folder="offload",  # Temporary offload to disk
            )
        logger.info(f"Model loaded successfully! ({device})")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise e
    yield
    # Release resources when the app is stopped
    del model
    del tokenizer
    cuda.empty_cache()

# Set up our app
app = FastAPI(
    title="Apertus API",
    description="REST API for serving Apertus models via Hugging Face transformers",
    version="0.1.0",
    docs_url="/",
    lifespan=lifespan
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def fit_to_length(text, min_length=3, max_length=100):
    """Truncate text that is too long; return None if it is too short."""
    text = text[:max_length]
    if len(text) == max_length:
        logger.warning("Warning: text truncated")
    if len(text) < min_length:
        logger.warning("Warning: text too short, aborting")
        return None
    return text


def get_completion_text(messages_think: List[ChatMessage]):
    """Concatenate the message contents into a single string."""
    txt = ""
    for cm in messages_think:
        txt = " ".join((txt, cm.content))
    return txt


def get_message_id(txt: str):
    """Derive a stable identifier from the prompt text."""
    return sha256(str(txt).encode()).hexdigest()


def get_model_response(messages_think: List[ChatMessage]):
    """Run the model on the chat messages and return the decoded reply."""
    # Prepend the system prompt if the conversation does not already include one
    has_system = any(m.role == 'system' for m in messages_think)
    if not has_system and system_prompt:
        cm = ChatMessage(role='system', content=system_prompt)
        messages_think.insert(0, cm)
    logger.debug(messages_think)
    # Prepare the model input
    text = tokenizer.apply_chat_template(
        messages_think,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer(
        [text],
        return_tensors="pt",
        add_special_tokens=False
    ).to(model.device)
    # Generate the output; sampling parameters belong here, not in the chat template
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
    )
    # Keep only the newly generated tokens
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
    # Decode the text message
    return tokenizer.decode(output_ids, skip_special_tokens=True)


# Route path follows the OpenAI chat completions convention (assumed; adjust as needed)
@app.post("/v1/chat/completions")
async def completion(data: Completion):
    """Generate an OpenAI-style chat completion."""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        mt = data.messages
        text = get_completion_text(mt)
        result = get_model_response(mt)
        # Standard formatted response object
        return {
            "id": get_message_id(text),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": data.model,
            "choices": [{
                "message": ChatMessage(role="assistant", content=result)
            }],
            # Character counts are used as a rough stand-in for token counts
            "usage": {
                "prompt_tokens": len(text),
                "completion_tokens": len(result),
                "total_tokens": len(text) + len(result)
            }
        }
    except Exception as e:
        logger.warning(e)
        raise HTTPException(status_code=400, detail="Could not process") from e


# Simple single-turn endpoint (route path and method assumed)
@app.post("/predict")
async def predict(q: str):
    """Generate a model response for input text."""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        start_time = time.time()
        input_data = TextInput(text=q)
        text = fit_to_length(input_data.text, input_data.min_length, input_data.max_length)
        if text is None:
            raise HTTPException(status_code=400, detail="Text too short")
        messages_think = [
            {"role": "user", "content": text}
        ]
        result = get_model_response(messages_think)
        # Checkpoint
        processing_time = time.time() - start_time
        return ModelResponse(
            text=result,
            confidence=0,  # No score is available from plain generation
            processing_time=processing_time
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.warning(e)
        raise HTTPException(status_code=500, detail="Evaluation failed") from e


# Health check (route path assumed)
@app.get("/health")
async def health_check():
    """Health check and basic configuration."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "gpu_available": cuda.is_available()
    }


if __name__ == '__main__':
    uvicorn.run('app:app', reload=True)
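
# Minimal usage sketch, assuming the server runs locally on uvicorn's default
# port (8000) and the route paths chosen above (/health, /v1/chat/completions, /predict):
#
#   curl http://localhost:8000/health
#
#   curl -X POST http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "apertus", "messages": [{"role": "user", "content": "Hi!"}]}'
#
#   curl -X POST "http://localhost:8000/predict?q=Hello"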