Spaces:
Running
Running
| import os | |
| import io | |
| import logging | |
| from typing import Tuple | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from PIL import Image | |
| import numpy as np | |
| import tensorflow as tf | |
# Configure root logging at INFO so this module's info-level messages are emitted.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)

# Named logger used throughout this module.
logger = logging.getLogger("vehicle-predictor")
# Keras model file, resolved relative to the directory containing this module.
MODEL_FILENAME = "complete_model_model.h5"
_MODULE_DIR = os.path.dirname(__file__)
MODEL_PATH = os.path.join(_MODULE_DIR, MODEL_FILENAME)

# Target size handed to PIL's Image.resize, i.e. (width, height).
IMG_SIZE = (224, 224)

# Class names indexed by model output position: argmax index i -> CLASS_NAMES[i].
# NOTE(review): this order must match the label order used at training time
# (it looks alphabetical, as directory-based loaders produce) -- confirm.
CLASS_NAMES = (
    "Ambulance Bicycle Boat Bus Car Helicopter Limousine "
    "Motorcycle PickUp Segway Snowmobile Tank Taxi Truck Van"
).split()
app = FastAPI(title="Vehicle Type Predictor")

# Cross-origin policy for browser clients: currently wide open.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is a
# risky pairing -- browsers refuse "*" for credentialed requests, and Starlette
# reportedly answers such requests by echoing the caller's Origin, which lets
# any site make credentialed calls. Nothing visible in this file uses cookies
# or auth headers; confirm whether credentials are needed before tightening.
_cors_options = {
    "allow_origins": ["*"],  # you can tighten this later if needed
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Load the model once, at import time. A failure is logged with its traceback
# instead of raised so the service still starts; `model` then stays None and
# the request handlers report that state (503 from predict,
# model_loaded=False from health).
model = None
logger.info("Loading model from %s ...", MODEL_PATH)
try:
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception:  # deliberate best-effort boundary: keep the app up, see above
    logger.exception("Model load failed")
else:
    logger.info("Model loaded successfully.")
# Response body for a single-image prediction.
# (Plain comments rather than a docstring: pydantic copies a class docstring
# into the generated JSON/OpenAPI schema.)
class PredictionResponse(BaseModel):
    label: str  # predicted class name, taken from CLASS_NAMES
    # Model output value for the predicted class; a probability only if the
    # model ends in a softmax (NOTE(review): unverified from this file).
    confidence: float
def preprocess_image_file(file_bytes: bytes) -> np.ndarray:
    """Decode raw image bytes into a one-image batch for the model.

    The image is converted to RGB, resized to ``IMG_SIZE`` and its 0..255
    pixel values are scaled to 0..1 float32.

    NOTE(review): assumes the model was trained on [0, 1]-scaled inputs that
    were resized with the same filter (Pillow's default for ``resize``) --
    confirm against the training pipeline.

    Args:
        file_bytes: Encoded image data in any format Pillow can open.

    Returns:
        float32 array of shape ``(1, IMG_SIZE[1], IMG_SIZE[0], 3)``.

    Raises:
        OSError: (from Pillow, e.g. ``UnidentifiedImageError``) when the bytes
            are not a recognisable or complete image.
    """
    # Close the decoded source image as soon as the RGB copy exists instead of
    # leaving it to the garbage collector; convert() fully loads the pixels,
    # so `rgb` does not depend on the closed source.
    with Image.open(io.BytesIO(file_bytes)) as src:
        rgb = src.convert("RGB")
    resized = rgb.resize(IMG_SIZE)
    pixels = np.asarray(resized).astype(np.float32) / 255.0
    # Add the leading batch axis the model expects.
    return pixels[np.newaxis, ...]
# NOTE(review): no route decorator is visible on this handler (nor on
# `health`), so as shown it is never registered on `app`; confirm the original
# had something like `@app.post(..., response_model=PredictionResponse)`.
async def predict(file: UploadFile = File(...)):
    """Classify the vehicle type shown in an uploaded image.

    Returns the top-1 class label and the model's score for it.

    Errors: 503 if the model failed to load at startup; 400 if the upload is
    not declared as ``image/*`` or cannot be decoded as an image; 500 for any
    other failure during inference.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # UploadFile.content_type may be None when the part carries no
    # Content-Type header; treat that as "not an image" instead of crashing.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        contents = await file.read()
        try:
            x = preprocess_image_file(contents)
        except (OSError, Image.DecompressionBombError) as exc:
            # Pillow reports unrecognised/truncated data with OSError
            # subclasses and oversized images with DecompressionBombError:
            # both are bad client input, not server faults.
            logger.warning("Could not decode image %r: %s", file.filename, exc)
            raise HTTPException(status_code=400, detail="Invalid image file") from exc
        # NOTE(review): model.predict is synchronous and blocks the event loop
        # for the whole inference; consider a worker thread (e.g.
        # fastapi.concurrency.run_in_threadpool) once the loaded model is
        # confirmed safe to call concurrently.
        # verbose=0 stops Keras printing a progress bar on every request.
        preds = model.predict(x, verbose=0)
        scores = preds[0]
        idx = int(np.argmax(scores))
        label = CLASS_NAMES[idx]
        confidence = float(scores[idx])
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from exc

    logger.info("Predicted %s (%.4f) for %s", label, confidence, file.filename)
    return PredictionResponse(label=label, confidence=confidence)
# NOTE(review): no route decorator is visible here; confirm this handler is
# registered on `app` (e.g. `@app.get(...)`).
def health():
    """Report liveness and whether the model loaded at startup."""
    payload = {"status": "ok"}
    payload["model_loaded"] = model is not None
    return payload