likhonsheikhdev's picture
Upload folder using huggingface_hub
f238f35 verified
raw
history blame
17.5 kB
"""
Docker Model Runner - Anthropic API Compatible
Full compatibility with Anthropic Messages API format
Optimized for: 2 vCPU, 16GB RAM
"""
from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import Optional, List, Union, Literal, Any, Dict
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from datetime import datetime
from contextlib import asynccontextmanager
import uuid
import time
import json
import asyncio
# CPU-optimized lightweight models
GENERATOR_MODEL = os.getenv("GENERATOR_MODEL", "distilgpt2")
MODEL_DISPLAY_NAME = os.getenv("MODEL_NAME", "MiniMax-M2")
# Set CPU threading
torch.set_num_threads(2)
# Global model cache
models = {}
def load_models():
"""Pre-load models for faster inference"""
global models
print("Loading models for CPU inference...")
models["tokenizer"] = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
models["model"] = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL)
models["model"].eval()
if models["tokenizer"].pad_token is None:
models["tokenizer"].pad_token = models["tokenizer"].eos_token
print("✅ All models loaded successfully!")
@asynccontextmanager
async def lifespan(app: FastAPI):
load_models()
yield
models.clear()
app = FastAPI(
title="Docker Model Runner",
description="Anthropic API Compatible Endpoint",
version="1.0.0",
lifespan=lifespan
)
# ============== Anthropic API Models ==============
class TextBlock(BaseModel):
type: Literal["text"] = "text"
text: str
class ThinkingBlock(BaseModel):
type: Literal["thinking"] = "thinking"
thinking: str
class ToolUseBlock(BaseModel):
type: Literal["tool_use"] = "tool_use"
id: str
name: str
input: Dict[str, Any]
class ToolResultContent(BaseModel):
type: Literal["tool_result"] = "tool_result"
tool_use_id: str
content: Union[str, List[TextBlock]]
is_error: Optional[bool] = False
class ImageSource(BaseModel):
type: Literal["base64", "url"]
media_type: Optional[str] = None
data: Optional[str] = None
url: Optional[str] = None
class ImageBlock(BaseModel):
type: Literal["image"] = "image"
source: ImageSource
ContentBlock = Union[TextBlock, ThinkingBlock, ToolUseBlock, ToolResultContent, ImageBlock, str]
class MessageParam(BaseModel):
role: Literal["user", "assistant"]
content: Union[str, List[ContentBlock]]
class ToolInputSchema(BaseModel):
type: str = "object"
properties: Optional[Dict[str, Any]] = None
required: Optional[List[str]] = None
class Tool(BaseModel):
name: str
description: str
input_schema: ToolInputSchema
class ToolChoice(BaseModel):
type: Literal["auto", "any", "tool"] = "auto"
name: Optional[str] = None
class ThinkingConfig(BaseModel):
type: Literal["enabled", "disabled"] = "disabled"
budget_tokens: Optional[int] = None
class Metadata(BaseModel):
user_id: Optional[str] = None
class AnthropicRequest(BaseModel):
model: str = "MiniMax-M2"
messages: List[MessageParam]
max_tokens: int = 1024
temperature: Optional[float] = Field(default=1.0, gt=0.0, le=1.0)
top_p: Optional[float] = Field(default=1.0, gt=0.0, le=1.0)
top_k: Optional[int] = None # Ignored
stop_sequences: Optional[List[str]] = None # Ignored
stream: Optional[bool] = False
system: Optional[Union[str, List[TextBlock]]] = None
tools: Optional[List[Tool]] = None
tool_choice: Optional[ToolChoice] = None
metadata: Optional[Metadata] = None
thinking: Optional[ThinkingConfig] = None
service_tier: Optional[str] = None # Ignored
class Usage(BaseModel):
input_tokens: int
output_tokens: int
cache_creation_input_tokens: Optional[int] = 0
cache_read_input_tokens: Optional[int] = 0
class AnthropicResponse(BaseModel):
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: List[Union[TextBlock, ThinkingBlock, ToolUseBlock]]
model: str
stop_reason: Optional[Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]] = "end_turn"
stop_sequence: Optional[str] = None
usage: Usage
# Streaming Event Models
class StreamEvent(BaseModel):
type: str
index: Optional[int] = None
content_block: Optional[Dict[str, Any]] = None
delta: Optional[Dict[str, Any]] = None
message: Optional[Dict[str, Any]] = None
usage: Optional[Dict[str, Any]] = None
# ============== Helper Functions ==============
def extract_text_from_content(content: Union[str, List[ContentBlock]]) -> str:
"""Extract text from content which may be string or list of blocks"""
if isinstance(content, str):
return content
texts = []
for block in content:
if isinstance(block, str):
texts.append(block)
elif hasattr(block, 'text'):
texts.append(block.text)
elif hasattr(block, 'thinking'):
texts.append(block.thinking)
elif isinstance(block, dict):
if block.get('type') == 'text':
texts.append(block.get('text', ''))
elif block.get('type') == 'thinking':
texts.append(block.get('thinking', ''))
return " ".join(texts)
def format_system_prompt(system: Optional[Union[str, List[TextBlock]]]) -> str:
"""Format system prompt from string or list of blocks"""
if system is None:
return ""
if isinstance(system, str):
return system
return " ".join([block.text for block in system if hasattr(block, 'text')])
def format_messages_to_prompt(messages: List[MessageParam], system: Optional[Union[str, List[TextBlock]]] = None) -> str:
"""Convert chat messages to a single prompt string"""
prompt_parts = []
system_text = format_system_prompt(system)
if system_text:
prompt_parts.append(f"System: {system_text}\n\n")
for msg in messages:
role = msg.role
content_text = extract_text_from_content(msg.content)
if role == "user":
prompt_parts.append(f"Human: {content_text}\n\n")
elif role == "assistant":
prompt_parts.append(f"Assistant: {content_text}\n\n")
prompt_parts.append("Assistant:")
return "".join(prompt_parts)
def generate_text(prompt: str, max_tokens: int, temperature: float, top_p: float) -> tuple:
"""Generate text and return (text, input_tokens, output_tokens)"""
tokenizer = models["tokenizer"]
model = models["model"]
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
input_tokens = inputs["input_ids"].shape[1]
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=min(max_tokens, 256), # Limit for CPU
temperature=temperature if temperature > 0 else 1.0,
top_p=top_p,
do_sample=temperature > 0,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
generated_tokens = outputs[0][input_tokens:]
output_tokens = len(generated_tokens)
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
return generated_text.strip(), input_tokens, output_tokens
async def generate_stream(prompt: str, max_tokens: int, temperature: float, top_p: float, message_id: str, model_name: str):
"""Generate streaming response in Anthropic SSE format"""
tokenizer = models["tokenizer"]
model = models["model"]
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
input_tokens = inputs["input_ids"].shape[1]
# Send message_start event
message_start = {
"type": "message_start",
"message": {
"id": message_id,
"type": "message",
"role": "assistant",
"content": [],
"model": model_name,
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": input_tokens, "output_tokens": 0}
}
}
yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n"
# Send content_block_start event
content_block_start = {
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""}
}
yield f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n"
# Generate tokens
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=min(max_tokens, 256),
temperature=temperature if temperature > 0 else 1.0,
top_p=top_p,
do_sample=temperature > 0,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
generated_tokens = outputs[0][input_tokens:]
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
output_tokens = len(generated_tokens)
# Stream text in chunks
chunk_size = 5
for i in range(0, len(generated_text), chunk_size):
chunk = generated_text[i:i+chunk_size]
content_block_delta = {
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": chunk}
}
yield f"event: content_block_delta\ndata: {json.dumps(content_block_delta)}\n\n"
await asyncio.sleep(0.01) # Small delay for realistic streaming
# Send content_block_stop event
content_block_stop = {"type": "content_block_stop", "index": 0}
yield f"event: content_block_stop\ndata: {json.dumps(content_block_stop)}\n\n"
# Send message_delta event
message_delta = {
"type": "message_delta",
"delta": {"stop_reason": "end_turn", "stop_sequence": None},
"usage": {"output_tokens": output_tokens}
}
yield f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n"
# Send message_stop event
message_stop = {"type": "message_stop"}
yield f"event: message_stop\ndata: {json.dumps(message_stop)}\n\n"
def handle_tool_call(tools: List[Tool], messages: List[MessageParam], generated_text: str) -> Optional[ToolUseBlock]:
"""Check if the response should trigger a tool call"""
if not tools:
return None
# Simple heuristic: check if response mentions tool names
for tool in tools:
if tool.name.lower() in generated_text.lower():
return ToolUseBlock(
type="tool_use",
id=f"toolu_{uuid.uuid4().hex[:24]}",
name=tool.name,
input={}
)
return None
# ============== Anthropic API Endpoints ==============
@app.post("/v1/messages")
async def create_message(request: AnthropicRequest):
"""
Anthropic Messages API compatible endpoint
POST /v1/messages
Supports:
- Text messages
- System prompts
- Streaming responses
- Tool/function calling
- Thinking/reasoning blocks
"""
try:
message_id = f"msg_{uuid.uuid4().hex[:24]}"
# Format messages to prompt
prompt = format_messages_to_prompt(request.messages, request.system)
# Handle streaming
if request.stream:
return StreamingResponse(
generate_stream(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature or 1.0,
top_p=request.top_p or 1.0,
message_id=message_id,
model_name=request.model
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# Non-streaming response
generated_text, input_tokens, output_tokens = generate_text(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature or 1.0,
top_p=request.top_p or 1.0
)
# Build content blocks
content_blocks = []
# Add thinking block if enabled
if request.thinking and request.thinking.type == "enabled":
thinking_text = f"Analyzing the user's request and formulating a response..."
content_blocks.append(ThinkingBlock(type="thinking", thinking=thinking_text))
# Check for tool calls
tool_use = handle_tool_call(request.tools, request.messages, generated_text) if request.tools else None
if tool_use:
content_blocks.append(TextBlock(type="text", text=generated_text))
content_blocks.append(tool_use)
stop_reason = "tool_use"
else:
content_blocks.append(TextBlock(type="text", text=generated_text))
stop_reason = "end_turn"
return AnthropicResponse(
id=message_id,
content=content_blocks,
model=request.model,
stop_reason=stop_reason,
usage=Usage(input_tokens=input_tokens, output_tokens=output_tokens)
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# ============== OpenAI Compatible Endpoints ==============
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
model: str = "distilgpt2"
messages: List[ChatMessage]
max_tokens: Optional[int] = 1024
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
stream: Optional[bool] = False
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
"""OpenAI Chat Completions API compatible endpoint"""
try:
# Convert to Anthropic format
anthropic_messages = [
MessageParam(role=msg.role if msg.role in ["user", "assistant"] else "user",
content=msg.content)
for msg in request.messages
if msg.role in ["user", "assistant"]
]
prompt = format_messages_to_prompt(anthropic_messages)
generated_text, input_tokens, output_tokens = generate_text(
prompt=prompt,
max_tokens=request.max_tokens or 1024,
temperature=request.temperature or 0.7,
top_p=request.top_p or 1.0
)
return {
"id": f"chatcmpl-{uuid.uuid4().hex[:24]}",
"object": "chat.completion",
"created": int(time.time()),
"model": request.model,
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": generated_text},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": input_tokens,
"completion_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens
}
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/models")
async def list_models():
"""List available models"""
return {
"object": "list",
"data": [
{"id": "MiniMax-M2", "object": "model", "created": int(time.time()), "owned_by": "local"},
{"id": "MiniMax-M2-Stable", "object": "model", "created": int(time.time()), "owned_by": "local"},
{"id": GENERATOR_MODEL, "object": "model", "created": int(time.time()), "owned_by": "local"}
]
}
# ============== Utility Endpoints ==============
@app.get("/")
async def root():
"""Welcome endpoint"""
return {
"message": "Docker Model Runner API (Anthropic Compatible)",
"hardware": "CPU Basic: 2 vCPU · 16 GB RAM",
"docs": "/docs",
"api_endpoints": {
"anthropic_messages": "POST /v1/messages",
"openai_chat": "POST /v1/chat/completions",
"models": "GET /v1/models"
},
"supported_features": [
"text messages",
"system prompts",
"streaming responses",
"tool/function calling",
"thinking blocks",
"metadata"
]
}
@app.get("/health")
async def health():
"""Health check endpoint"""
return {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat(),
"hardware": "CPU Basic: 2 vCPU · 16 GB RAM",
"models_loaded": len(models) > 0
}
@app.get("/info")
async def info():
"""API information"""
return {
"name": "Docker Model Runner",
"version": "1.0.0",
"api_compatibility": ["anthropic", "openai"],
"supported_models": ["MiniMax-M2", "MiniMax-M2-Stable"],
"supported_parameters": {
"fully_supported": ["model", "messages", "max_tokens", "stream", "system", "temperature", "top_p", "tools", "tool_choice", "metadata", "thinking"],
"ignored": ["top_k", "stop_sequences", "service_tier"]
},
"message_types": {
"supported": ["text", "tool_use", "tool_result", "thinking"],
"not_supported": ["image", "document"]
}
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)