"""Measure how well the GyroScope model recognizes image rotation.

Downloads random images, rotates each one by a randomly chosen angle from
ANGLES, asks the model to classify the rotation, and prints the overall and
per-angle accuracy.
"""

# Emitted before the imports below so there is visible output while they load.
print("[*] Setting up...")

import random
from collections import Counter
from io import BytesIO

import numpy as np  # NOTE(review): `np` is not referenced in this script — remove if unused
import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import ResNetForImageClassification

# Rotation classes. A class's position in this list doubles as its label index
# (labels are built with ANGLES.index, predictions mapped back via ANGLES[idx]).
ANGLES = [0, 90, 180, 270]
# Number of images to download and evaluate.
NUM_IMAGES = 500
# Hugging Face model id of the rotation classifier.
MODEL_NAME = "LH-Tech-AI/GyroScope"
# Random-image endpoint; the path presumably requests 400x400 images.
IMG_SOURCE_URL = "https://loremflickr.com/400/400/all"

# Prefer the GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"[*] Using device: {device}")

print(f"[*] Loading model {MODEL_NAME}...")
# Load the pretrained classifier, move it to the selected device and switch it
# to inference mode (eval() makes layers such as batch-norm use evaluation behavior).
model = ResNetForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()

# Input pipeline: scale the shorter side to 256 px, take the central 224x224
# crop, convert to a float tensor, then normalize per channel with the common
# ImageNet mean/std. NOTE(review): assumed to match the model's training
# preprocessing — confirm against the model's image-processor config.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# One entry per successfully processed image:
# {"true": applied angle, "pred": predicted angle, "correct": bool}.
# Images that raise anywhere in the pipeline are reported and left out.
results = []

print(f"[*] Starting download and evaluation of {NUM_IMAGES} images (In-Memory)...")

bar_length = 20  # width of the text progress bar, in characters

# A Session reuses connections across the requests in the loop instead of
# opening a new one per image, and is closed when the block exits.
with requests.Session() as session:
    for i in range(1, NUM_IMAGES + 1):
        try:
            # The varying query parameter presumably makes each request return a
            # different image — TODO confirm against the image service.
            response = session.get(f"{IMG_SOURCE_URL}?random={i}", timeout=10)
            # Surface HTTP errors (4xx/5xx) directly instead of handing an error
            # page to PIL and getting an opaque "cannot identify image" error.
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert("RGB")

            # Draw the ground-truth rotation; its position in ANGLES is the label.
            true_angle = random.choice(ANGLES)
            label_idx = ANGLES.index(true_angle)

            # PIL rotates counter-clockwise; expand=True enlarges the canvas so
            # the rotated image is not cropped to the original size.
            rotated_img = img.rotate(true_angle, expand=True)

            # Single-image batch on the model's device; no gradients are needed.
            tensor = preprocess(rotated_img).unsqueeze(0).to(device)
            with torch.no_grad():
                logits = model(pixel_values=tensor).logits
            pred_idx = logits.argmax().item()

            is_correct = pred_idx == label_idx
            results.append({
                "true": true_angle,
                "pred": ANGLES[pred_idx],
                "correct": is_correct,
            })
        except Exception as e:
            # Best effort: report the failing image and continue with the next one.
            print(f"\n[!] Error processing image {i}: {e}")
        else:
            # Redraw the progress line in place (carriage return, no newline).
            percent = (i / NUM_IMAGES) * 100
            filled_length = bar_length * i // NUM_IMAGES
            bar = "#" * filled_length + " " * (bar_length - filled_length)
            status = "✓" if is_correct else "✗"
            print(f"\rProgress: [{bar}] {percent:.1f}% ({i}/{NUM_IMAGES}) | Last result: {status}", end="")

# Final summary. `results` holds one dict per successfully evaluated image
# ({"true", "pred", "correct"}); images that raised during the loop are absent.
print("\n\n" + "=" * 15)
print(" RESULTS")
print("=" * 15)

num_evaluated = len(results)
num_failed = NUM_IMAGES - num_evaluated

if not results:
    # Nothing to score (e.g. every download failed); dividing by
    # len(results) would otherwise raise ZeroDivisionError.
    print(f"No images were evaluated successfully ({num_failed} failed).")
else:
    total_correct = sum(1 for r in results if r["correct"])
    accuracy = (total_correct / num_evaluated) * 100

    print(f"Overall result: {total_correct}/{num_evaluated} correct")
    print(f"Hit rate: {accuracy:.2f} %")
    if num_failed:
        # The hit rate above is computed over evaluated images only.
        print(f"Failed (excluded): {num_failed}")
    print("-" * 30)

    print("Details per rotation class:")
    # Tally samples and hits per true angle.
    class_totals = Counter(r["true"] for r in results)
    class_hits = Counter(r["true"] for r in results if r["correct"])
    # Pad counts to the widest possible value so the columns stay aligned
    # (minimum width 2, as before, for small runs).
    width = max(2, len(str(num_evaluated)))
    for angle in ANGLES:
        n_class = class_totals[angle]
        if n_class:  # classes with no evaluated samples are omitted
            n_hit = class_hits[angle]
            class_acc = (n_hit / n_class) * 100
            print(f" {angle:>3}° : {n_hit:>{width}}/{n_class:>{width}} correct ({class_acc:>6.2f}%)")

print("=" * 30)