print("[*] Setting up...")

import random
from io import BytesIO

import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import ResNetForImageClassification

# --- 1. CONFIGURATION & SETUP ---
ANGLES = [0, 90, 180, 270]  # rotation classes; the index in this list is the class label
NUM_IMAGES = 500
MODEL_NAME = "LH-Tech-AI/GyroScope"
IMG_SOURCE_URL = "https://loremflickr.com/400/400/all"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[*] Using device: {device}")

# Load the model
print(f"[*] Loading model {MODEL_NAME}...")
model = ResNetForImageClassification.from_pretrained(MODEL_NAME)
model.eval()
model.to(device)

# Preprocessing (ImageNet-style resize, center crop, and normalization)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

results = []

# --- 2. EVALUATION LOOP ---
print(f"[*] Starting download and evaluation of {NUM_IMAGES} images (in memory)...")

for i in range(1, NUM_IMAGES + 1):
    try:
        # Download the image straight into RAM; a distinct query parameter per request
        # lets each call return a different image
        response = requests.get(f"{IMG_SOURCE_URL}?random={i}", timeout=10)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")

        # Pick a random rotation; its position in ANGLES is the ground-truth label
        true_angle = random.choice(ANGLES)
        label_idx = ANGLES.index(true_angle)

        # Rotate the image (PIL rotates counter-clockwise; expand=True keeps the whole image)
        rotated_img = img.rotate(true_angle, expand=True)

        # Prediction
        tensor = preprocess(rotated_img).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(pixel_values=tensor).logits
        pred_idx = logits.argmax(dim=-1).item()

        is_correct = pred_idx == label_idx
        results.append({
            "true": true_angle,
            "pred": ANGLES[pred_idx],
            "correct": is_correct,
        })

        # Single-line progress bar
        percent = (i / NUM_IMAGES) * 100
        bar_length = 20
        filled_length = bar_length * i // NUM_IMAGES
        bar = "#" * filled_length + " " * (bar_length - filled_length)
        status = "✓" if is_correct else "✗"
        print(f"\rProgress: [{bar}] {percent:.1f}% ({i}/{NUM_IMAGES}) | Last result: {status}", end="")

    except Exception as e:
        print(f"\n[!] Error processing image {i}: {e}")

# --- 3. RESULTS ---
print("\n\n" + "=" * 15)
print(" RESULTS")
print("=" * 15)

if not results:
    raise SystemExit("[!] No image was evaluated successfully; nothing to report.")

total_correct = sum(1 for r in results if r["correct"])
accuracy = (total_correct / len(results)) * 100
print(f"Overall result: {total_correct}/{len(results)} correct")
print(f"Hit rate: {accuracy:.2f} %")
print("-" * 30)
print("Details per rotation class:")

for angle in ANGLES:
    class_results = [r for r in results if r["true"] == angle]
    if class_results:
        correct_in_class = sum(1 for r in class_results if r["correct"])
        class_acc = (correct_in_class / len(class_results)) * 100
        print(f" {angle:>3}° : {correct_in_class:>2}/{len(class_results):>2} correct ({class_acc:>6.2f}%)")

print("=" * 30)

# Result of our benchmark:
# ===============
#  RESULTS
# ===============
# Overall result: 411/500 correct
# Hit rate: 82.20 %
# ------------------------------
# Details per rotation class:
#    0° : 96/124 correct ( 77.42%)
#   90° : 103/119 correct ( 86.55%)
#  180° : 112/129 correct ( 86.82%)
#  270° : 100/128 correct ( 78.12%)
# ==============================
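
# --- 4. OPTIONAL: CONFUSION MATRIX (sketch, not part of the original benchmark) ---
# The per-class numbers above say how often each rotation is recognized, but not
# which angle it is mistaken for. This minimal sketch tabulates true vs. predicted
# angle from the `results` list, assuming each entry has the "true" and "pred" keys
# used above. `confusion_counts` and `print_confusion_matrix` are hypothetical helper
# names introduced only for illustration.
from collections import Counter


def confusion_counts(records, angles):
    """Return a dict mapping (true_angle, pred_angle) -> count, with zeros for unseen pairs."""
    counts = Counter((r["true"], r["pred"]) for r in records)
    return {(t, p): counts[(t, p)] for t in angles for p in angles}


def print_confusion_matrix(records, angles):
    """Print a table with rows = true angle and columns = predicted angle."""
    matrix = confusion_counts(records, angles)
    print("true\\pred" + "".join(f"{a:>6}°" for a in angles))
    for t in angles:
        row = "".join(f"{matrix[(t, p)]:>7}" for p in angles)
        print(f"{t:>8}°{row}")


# Runs only when this block sits at the end of the benchmark script above,
# where `results` and `ANGLES` are already defined.
if "results" in globals() and "ANGLES" in globals():
    print("Confusion matrix (rows: true angle, columns: predicted angle):")
    print_confusion_matrix(results, ANGLES)

# Off-diagonal cells show the specific mix-ups, e.g. whether 0° errors mostly land on 180°.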