# botInfinity: "Update main.py" (commit 67d88cf, verified)
import io
import torch
import timm
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
from torchvision import transforms
# =========================
# App Init
# =========================
# FastAPI application object that the route decorators below attach to.
app = FastAPI(title="Vehicle Type Classifier API")

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# Checkpoint file read by the startup hook.
MODEL_PATH = "vehicle_classifier_best.pth"

# Both are filled in by the startup hook and stay None until it has run.
model = class_names = None
# =========================
# Response Schema
# =========================
class PredictionResponse(BaseModel):
    """Response body returned by the /predict endpoint."""

    # Name of the predicted class, looked up in the checkpoint's class_names.
    label: str
    # Softmax probability of the predicted class, rounded to 4 decimal places.
    confidence: float
# =========================
# Image Transform
# =========================
# Preprocessing applied to every uploaded image before inference:
# resize to 224x224, convert to a tensor, then normalise each channel.
# The mean/std values match the widely used ImageNet channel statistics.
# NOTE(review): presumably the same preprocessing used in training - confirm.
_NORMALIZE = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
_PREPROCESS_STEPS = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    _NORMALIZE,
]
transform = transforms.Compose(_PREPROCESS_STEPS)
# =========================
# Load Model on Startup
# =========================
@app.on_event("startup")
def load_model():
    """Load the checkpoint at MODEL_PATH and publish the model and class names.

    Reads ``class_names`` and ``model_state_dict`` from the checkpoint, builds
    a ``convnext_tiny`` network with one output per class, restores the
    weights, moves the network to ``device`` and switches it to eval mode.

    Side effects:
        Rebinds the module-level ``model`` and ``class_names`` and prints a
        confirmation line.

    Raises:
        Any exception raised while reading the checkpoint or restoring the
        weights propagates to the caller; in that case the module-level
        ``model`` and ``class_names`` are left untouched.
    """
    global model, class_names

    # SECURITY: torch.load is pickle-based. Depending on the installed torch
    # version and its ``weights_only`` default, loading a file can execute
    # arbitrary code, so MODEL_PATH must only ever point at a trusted file.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    names = checkpoint["class_names"]

    net = timm.create_model(
        "convnext_tiny",
        pretrained=False,
        num_classes=len(names),
    )
    net.load_state_dict(checkpoint["model_state_dict"])
    net.to(device)
    net.eval()

    # Publish the globals only once the network is fully initialised, so the
    # `model is None` guard in /predict never observes a half-loaded model.
    class_names = names
    model = net
    print("✅ Model loaded successfully")
# =========================
# Prediction Endpoint
# =========================
@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top class with its probability.

    Args:
        file: Uploaded file; its declared content type must start with
            ``image/``.

    Returns:
        PredictionResponse holding the predicted class name and its softmax
        probability rounded to 4 decimal places.

    Raises:
        HTTPException: 503 if the model has not been loaded, 400 if the
            upload is not declared as an image or cannot be decoded as one,
            500 if preprocessing or inference fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # content_type is optional on an upload; treat a missing value as
    # "not an image" instead of crashing on None.startswith.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    image_bytes = await file.read()

    # Decoding failures (empty, corrupt, or oversized uploads) are the
    # client's fault, so report them as 400 rather than a server error.
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(
            status_code=400,
            detail="Invalid or unsupported image file",
        ) from e

    # NOTE(review): inference runs synchronously inside this coroutine and so
    # occupies the event loop for its duration; consider a thread pool if
    # concurrent latency matters.
    try:
        # Add a leading batch dimension of 1 before moving to the device.
        img_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(img_tensor)
            probs = torch.softmax(outputs, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        label = class_names[pred_idx]
        confidence = probs[0][pred_idx].item()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return PredictionResponse(
        label=label,
        confidence=round(confidence, 4),
    )
# =========================
# Health Check
# =========================
@app.get("/health")
def root():
    """Health-check endpoint: report that the API process is responding.

    This returns a fixed payload and does not check whether the model has
    been loaded.

    Returns:
        A dict with a single ``status`` key holding a human-readable message.
    """
    return {"status": "Vehicle Classifier API is running 🚗🏍️"}