from fastapi import FastAPI, Request, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import logging


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Hugging Face Spaces sets SPACE_ID; use it to derive the base path the app is served under
HF_SPACE = os.getenv("SPACE_ID", "")
BASE_PATH = f"/spaces/{HF_SPACE}" if HF_SPACE else ""


app = FastAPI(
    title="DialoGPT API",
    description="Chatbot API using Microsoft's DialoGPT-medium model",
    version="1.0",
    root_path=BASE_PATH,
    # root_path already tells FastAPI about the proxy prefix, so docs_url stays
    # relative; prefixing it again would double the path (/spaces/x/spaces/x/docs)
    docs_url="/docs",
    redoc_url=None
)


# Load the model once at import time so requests don't pay the cost; fail fast
# if initialization is impossible
try:
    logger.info("Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    raise RuntimeError("Model initialization failed") from e
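
# Note: from_pretrained() downloads the weights into the local Hugging Face cache
# on first run (roughly 1.4 GB in fp32 for DialoGPT-medium), so cold starts are slow.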


# Per-user conversation state kept in memory: user_id -> [tensor of token ids
# for the whole conversation so far]. Lost on restart and never evicted.
chat_history = {}
| @app.get("/", include_in_schema=False) |
| async def root(): |
| return {"message": "🟢 API is running. Use /ai?query=Hello&user_id=yourname"} |
|
|
| @app.get("/ai") |
| async def chat(request: Request): |
| try: |
| |
| user_input = request.query_params.get("query", "").strip() |
| user_id = request.query_params.get("user_id", "default").strip() |
| |
| |
| if not user_input: |
| raise HTTPException( |
| status_code=400, |
| detail="Missing 'query' parameter. Usage: /ai?query=Hello&user_id=yourname" |
| ) |
| if len(user_input) > 200: |
| raise HTTPException( |
| status_code=400, |
| detail="Query too long (max 200 characters)" |
| ) |

        # Encode the new user turn, terminated by the EOS token DialoGPT uses
        # as a turn separator
        new_input_ids = tokenizer.encode(
            user_input + tokenizer.eos_token,
            return_tensors='pt'
        )

        # Prepend the stored conversation, if any
        user_history = chat_history.get(user_id, [])
        bot_input_ids = torch.cat(user_history + [new_input_ids], dim=-1) if user_history else new_input_ids

        # DialoGPT is a GPT-2 model with a 1024-token context window; keep only
        # the most recent tokens so long conversations don't overflow it
        if bot_input_ids.shape[-1] > 900:
            bot_input_ids = bot_input_ids[:, -900:]

        output_ids = model.generate(
            bot_input_ids,
            max_new_tokens=100,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95
        )

        # generate() returns the prompt followed by the new tokens; slice off
        # the prompt to keep only the bot's reply
        response = tokenizer.decode(
            output_ids[:, bot_input_ids.shape[-1]:][0],
            skip_special_tokens=True
        ).strip()

        # output_ids already contains the whole conversation (prompt + reply), so
        # store it alone; also keeping bot_input_ids would duplicate the history
        # on the next turn
        chat_history[user_id] = [output_ids]

        return {"reply": response}

    except HTTPException:
        # Let the 400s raised above pass through instead of turning into 500s
        raise

    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        # Drop this user's history to free memory before asking them to retry
        if user_id in chat_history:
            del chat_history[user_id]
        raise HTTPException(
            status_code=500,
            detail="Memory error. Conversation history cleared. Please try again."
        )

    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Processing error: {str(e)}"
        ) from e
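
# Example call (hypothetical local run; replies vary because sampling is enabled):
#   curl "http://localhost:7860/ai?query=Hello&user_id=alice"
#   -> {"reply": "..."}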


@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model": "microsoft/DialoGPT-medium",
        "users": len(chat_history),
        "space_id": HF_SPACE
    }


@app.get("/reset")
async def reset_history(user_id: str = "default"):
    if user_id in chat_history:
        del chat_history[user_id]
    return {"status": "success", "message": f"History cleared for user {user_id}"}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        timeout_keep_alive=30
    )
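
# To run locally (assuming this file is saved as app.py):
#   python app.py
# or directly with uvicorn:
#   uvicorn app:app --host 0.0.0.0 --port 7860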