File size: 2,660 Bytes
dae61a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 29 11:12:02 2026

@author: atulkar
"""

import os
import gradio as gr
from transformers import pipeline

# -----------------------------
# Image-classification model
# -----------------------------
# A pretrained ViT checkpoint from the Hugging Face Hub, used as a
# general-purpose ImageNet-style classifier. Calling it on an image yields a
# list of {"label": ..., "score": ...} dicts, which classify_image reshapes
# for the gr.Label output.
clf = pipeline("image-classification", model="google/vit-base-patch16-224")

# -----------------------------
# Locate example images (works locally + on HF Spaces)
# -----------------------------
# Resolve paths relative to this file rather than the current working
# directory, so the examples are found wherever the app is launched from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
EXAMPLES_DIR = os.path.join(BASE_DIR, "animal_images")

# Filenames expected inside EXAMPLES_DIR.
EXAMPLE_FILES = [
    "cat.png", "frog.png", "hippo.png", "jaguar.png",
    "sloth.png", "toucan.png", "turtle.png",
]

# examples: one single-element row [path] per file that exists on disk
#           (this list of rows is what gets passed to gr.Examples).
# missing:  bare filenames from EXAMPLE_FILES that were not found.
examples = []
missing = []
for filename in EXAMPLE_FILES:
    candidate = os.path.join(EXAMPLES_DIR, filename)
    if not os.path.exists(candidate):
        missing.append(filename)
        continue
    examples.append([candidate])

# -----------------------------
# Prediction function
# -----------------------------
def classify_image(img):
    """
    Classify one image and shape the result for a ``gr.Label`` output.

    ``img`` is a PIL image, because the input component is declared as
    ``gr.Image(type="pil")``. It is None when no image has been provided.

    Returns a dict mapping each of the top-3 predicted labels to its
    confidence as a float, or an empty dict when ``img`` is None.
    """
    if img is None:
        # Nothing to classify: leave the prediction panel empty.
        return {}

    confidences = {}
    for prediction in clf(img, top_k=3):
        confidences[prediction["label"]] = float(prediction["score"])
    return confidences


# -----------------------------
# Build Gradio App
# -----------------------------
with gr.Blocks(title="Animal Image Classifier") as demo:
    gr.Markdown("# Animal Image Classifier")
    gr.Markdown(
        "Upload an animal image (or click an example). "
        "This app uses a Hugging Face `image-classification` pipeline."
    )

    with gr.Row():
        # Left column: the image input plus Submit / Clear controls.
        with gr.Column(scale=1):
            inp = gr.Image(type="pil", label="Input Image")
            with gr.Row():
                btn = gr.Button("Submit", variant="primary")
                clr = gr.Button("Clear")
        # Right column: the top-3 predicted labels with their confidences.
        with gr.Column(scale=1):
            out = gr.Label(num_top_classes=3, label="Top Predictions")

    # Classification runs only on an explicit Submit click.
    btn.click(fn=classify_image, inputs=inp, outputs=out)
    # Clear resets both the input image and the prediction panel.
    clr.click(fn=lambda: (None, {}), inputs=None, outputs=[inp, out])

    # Only offer the example images that were actually found on disk.
    if examples:
        gr.Examples(
            examples=examples,
            inputs=inp,
            label="Examples (from ./animal_images/)",
        )

    # Name the missing example files and the folder they were expected in,
    # so the user knows exactly what to add before redeploying.
    if missing:
        missing_list = ", ".join(f"`{fname}`" for fname in missing)
        gr.Markdown(
            f"**Missing example images** in `animal_images/`: {missing_list}"
            "\n\nMake sure the folder is next to `app.py` locally, "
            "and uploaded into your Hugging Face Space repo when deploying."
        )

# -----------------------------
# Launch
# -----------------------------
# Start the Gradio server only when this file is executed directly as a
# script (e.g. `python app.py`), not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()