Adapters
Thai
art
Mangacolor / manga_colorizer_auto.py
K1Z3M1112's picture
Upload 3 files
d243596 verified
# manga_colorizer_auto.py
import os
import time
import json
import random
import requests
import psutil
import GPUtil
import cv2
import numpy as np
from pathlib import Path
from bs4 import BeautifulSoup
from PIL import Image
import cloudscraper
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import threading
from queue import Queue
import warnings
warnings.filterwarnings('ignore')
# ==================== Resource Management ====================
class ResourceManager:
def __init__(self):
self.resource_profile = self.analyze_system_resources()
self.save_resource_profile()
def analyze_system_resources(self):
"""วิเคราะห์ทรัพยากรระบบอัตโนมัติ"""
try:
profile = {
'gpu_available': torch.cuda.is_available(),
'gpu_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
'gpu_memory': [],
'cpu_cores': psutil.cpu_count(logical=False),
'cpu_threads': psutil.cpu_count(logical=True),
'ram_total_gb': psutil.virtual_memory().total / (1024**3),
'ram_available_gb': psutil.virtual_memory().available / (1024**3),
'disk_free_gb': psutil.disk_usage('.').free / (1024**3),
}
# ตรวจสอบ GPU memory
if profile['gpu_available']:
try:
gpus = GPUtil.getGPUs()
for gpu in gpus:
profile['gpu_memory'].append({
'name': gpu.name,
'memory_total': gpu.memoryTotal,
'memory_free': gpu.memoryFree
})
except:
profile['gpu_memory'] = []
# กำหนดระดับระบบ
if profile['gpu_available'] and profile['ram_total_gb'] >= 16:
profile['system_type'] = 'high_end'
elif profile['gpu_available'] or profile['ram_total_gb'] >= 8:
profile['system_type'] = 'medium_end'
else:
profile['system_type'] = 'low_end'
print(f"🔍 ตรวจสอบระบบ: {profile['system_type'].upper()}")
print(f"💻 CPU: {profile['cpu_cores']} cores, {profile['cpu_threads']} threads")
print(f"🧠 RAM: {profile['ram_total_gb']:.1f}GB (ใช้ได้ {profile['ram_available_gb']:.1f}GB)")
print(f"🎮 GPU: {profile['gpu_available']} ({profile['gpu_count']} devices)")
if profile['gpu_memory']:
for gpu in profile['gpu_memory']:
print(f" - {gpu['name']}: {gpu['memory_total']}MB")
return profile
except Exception as e:
print(f"❌ ตรวจสอบทรัพยากรล้มเหลว: {e}")
# ค่าดีฟอลต์สำหรับระบบพื้นฐาน
return {
'system_type': 'low_end',
'cpu_cores': 2,
'ram_total_gb': 4,
'gpu_available': False,
'gpu_count': 0
}
def save_resource_profile(self):
"""บันทึกโปรไฟล์ทรัพยากร"""
try:
with open('resource_profile.json', 'w', encoding='utf-8') as f:
json.dump(self.resource_profile, f, indent=2, ensure_ascii=False)
except Exception as e:
print(f"❌ บันทึกโปรไฟล์ล้มเหลว: {e}")
def get_optimized_settings(self):
"""ได้การตั้งค่าที่เหมาะกับทรัพยากร"""
system_type = self.resource_profile['system_type']
if system_type == 'high_end':
return {
'batch_size': 12,
'image_size': 512,
'dataloader_workers': 6,
'max_galleries_per_cycle': 15,
'training_epochs': 25,
'use_gan': True,
'model_complexity': 'high',
'max_cache_size_gb': 30,
'parallel_download': True,
'max_images_per_gallery': 50
}
elif system_type == 'medium_end':
return {
'batch_size': 6,
'image_size': 384,
'dataloader_workers': 3,
'max_galleries_per_cycle': 8,
'training_epochs': 15,
'use_gan': False,
'model_complexity': 'medium',
'max_cache_size_gb': 15,
'parallel_download': False,
'max_images_per_gallery': 30
}
else: # low_end
return {
'batch_size': 2,
'image_size': 256,
'dataloader_workers': 1,
'max_galleries_per_cycle': 3,
'training_epochs': 8,
'use_gan': False,
'model_complexity': 'low',
'max_cache_size_gb': 5,
'parallel_download': False,
'max_images_per_gallery': 20
}
# ==================== Neural Network Models ====================
class SimpleColorizationModel(nn.Module):
def __init__(self, base_channels=32):
super().__init__()
# Encoder
self.enc1 = nn.Sequential(
nn.Conv2d(3, base_channels, 3, padding=1),
nn.BatchNorm2d(base_channels),
nn.ReLU(inplace=True),
nn.Conv2d(base_channels, base_channels, 3, padding=1),
nn.BatchNorm2d(base_channels),
nn.ReLU(inplace=True),
nn.MaxPool2d(2)
)
self.enc2 = nn.Sequential(
nn.Conv2d(base_channels, base_channels*2, 3, padding=1),
nn.BatchNorm2d(base_channels*2),
nn.ReLU(inplace=True),
nn.Conv2d(base_channels*2, base_channels*2, 3, padding=1),
nn.BatchNorm2d(base_channels*2),
nn.ReLU(inplace=True),
nn.MaxPool2d(2)
)
self.enc3 = nn.Sequential(
nn.Conv2d(base_channels*2, base_channels*4, 3, padding=1),
nn.BatchNorm2d(base_channels*4),
nn.ReLU(inplace=True),
nn.Conv2d(base_channels*4, base_channels*4, 3, padding=1),
nn.BatchNorm2d(base_channels*4),
nn.ReLU(inplace=True),
nn.MaxPool2d(2)
)
# Bottleneck
self.bottleneck = nn.Sequential(
nn.Conv2d(base_channels*4, base_channels*8, 3, padding=1),
nn.BatchNorm2d(base_channels*8),
nn.ReLU(inplace=True),
nn.Conv2d(base_channels*8, base_channels*8, 3, padding=1),
nn.BatchNorm2d(base_channels*8),
nn.ReLU(inplace=True),
)
# Decoder
self.dec1 = nn.Sequential(
nn.ConvTranspose2d(base_channels*8, base_channels*4, 2, stride=2),
nn.BatchNorm2d(base_channels*4),
nn.ReLU(inplace=True),
)
self.dec2 = nn.Sequential(
nn.ConvTranspose2d(base_channels*8, base_channels*2, 2, stride=2),
nn.BatchNorm2d(base_channels*2),
nn.ReLU(inplace=True),
)
self.dec3 = nn.Sequential(
nn.ConvTranspose2d(base_channels*4, base_channels, 2, stride=2),
nn.BatchNorm2d(base_channels),
nn.ReLU(inplace=True),
)
# Output
self.output = nn.Conv2d(base_channels*2, 3, 3, padding=1)
def forward(self, x):
# Encoder
e1 = self.enc1(x) # base_channels x H/2 x W/2
e2 = self.enc2(e1) # base_channels*2 x H/4 x W/4
e3 = self.enc3(e2) # base_channels*4 x H/8 x W/8
# Bottleneck
b = self.bottleneck(e3) # base_channels*8 x H/8 x W/8
# Decoder with skip connections
d1 = self.dec1(b) # base_channels*4 x H/4 x W/4
d1 = torch.cat([d1, e2], dim=1) # base_channels*8 x H/4 x W/4
d2 = self.dec2(d1) # base_channels*2 x H/2 x W/2
d2 = torch.cat([d2, e1], dim=1) # base_channels*4 x H/2 x W/2
d3 = self.dec3(d2) # base_channels x H x W
output = torch.sigmoid(self.output(d3))
return output
class AdvancedColorizationModel(nn.Module):
def __init__(self, base_channels=64, use_attention=True):
super().__init__()
self.use_attention = use_attention
# Enhanced Encoder
self.enc1 = self._make_encoder_block(3, base_channels)
self.enc2 = self._make_encoder_block(base_channels, base_channels*2)
self.enc3 = self._make_encoder_block(base_channels*2, base_channels*4)
self.enc4 = self._make_encoder_block(base_channels*4, base_channels*8)
# Attention layers (ถ้าเปิดใช้)
if use_attention:
self.attention1 = nn.MultiheadAttention(base_channels*8, 8)
self.attention2 = nn.MultiheadAttention(base_channels*4, 8)
# Enhanced Decoder with more skip connections
self.dec1 = self._make_decoder_block(base_channels*8, base_channels*4)
self.dec2 = self._make_decoder_block(base_channels*8, base_channels*2) # + skip
self.dec3 = self._make_decoder_block(base_channels*4, base_channels) # + skip
self.dec4 = self._make_decoder_block(base_channels*2, base_channels//2) # + skip
# Final output with residual connection
self.output = nn.Sequential(
nn.Conv2d(base_channels, base_channels//2, 3, padding=1),
nn.BatchNorm2d(base_channels//2),
nn.ReLU(inplace=True),
nn.Conv2d(base_channels//2, 3, 3, padding=1)
)
self.residual_conv = nn.Conv2d(3, 3, 1)
def _make_encoder_block(self, in_ch, out_ch):
return nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.MaxPool2d(2)
)
def _make_decoder_block(self, in_ch, out_ch):
return nn.Sequential(
nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
# Store input for residual connection
x_input = x
# Encoder
e1 = self.enc1(x) # base_channels
e2 = self.enc2(e1) # base_channels*2
e3 = self.enc3(e2) # base_channels*4
e4 = self.enc4(e3) # base_channels*8
# Attention (if enabled)
if self.use_attention:
# Reshape for attention
b, c, h, w = e4.shape
e4_flat = e4.view(b, c, -1).permute(2, 0, 1) # (h*w, b, c)
e4_att, _ = self.attention1(e4_flat, e4_flat, e4_flat)
e4 = e4_att.permute(1, 2, 0).view(b, c, h, w)
# Decoder with skip connections
d1 = self.dec1(e4) # base_channels*4
d1 = torch.cat([d1, e3], dim=1) # base_channels*8
d2 = self.dec2(d1) # base_channels*2
d2 = torch.cat([d2, e2], dim=1) # base_channels*4
d3 = self.dec3(d2) # base_channels
d3 = torch.cat([d3, e1], dim=1) # base_channels*2
d4 = self.dec4(d3) # base_channels//2
# Output with residual connection
output = self.output(d4)
residual = self.residual_conv(x_input)
output = output + residual
return torch.sigmoid(output)
# ==================== Data Collection ====================
class SmartDataCollector:
def __init__(self, settings):
self.settings = settings
self.scraper = cloudscraper.create_scraper()
self.downloaded_count = 0
self.processed_galleries = set()
self.load_processed_list()
# สร้างโฟลเดอร์เก็บข้อมูล
os.makedirs('training_data', exist_ok=True)
os.makedirs('temp_downloads', exist_ok=True)
def load_processed_list(self):
"""โหลดรายการ gallery ที่ประมวลผลแล้ว"""
try:
if os.path.exists('processed_galleries.json'):
with open('processed_galleries.json', 'r') as f:
self.processed_galleries = set(json.load(f))
except:
self.processed_galleries = set()
def save_processed_list(self):
"""บันทึกรายการ gallery ที่ประมวลผลแล้ว"""
try:
with open('processed_galleries.json', 'w') as f:
json.dump(list(self.processed_galleries), f)
except Exception as e:
print(f"❌ บันทึกรายการล้มเหลว: {e}")
def get_sample_image_urls(self):
"""ให้ URL ภาพตัวอย่างสำหรับทดสอบ (หลีกเลี่ยงการดาวน์โหลดจริง)"""
print("📝 ใช้ภาพตัวอย่างสำหรับทดสอบ...")
# สร้างภาพทดสอบสีและขาวดำ
sample_images = []
for i in range(20): # สร้าง 20 ภาพตัวอย่าง
# สร้างภาพสีสุ่ม
color_img = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
gray_img = cv2.cvtColor(color_img, cv2.COLOR_RGB2GRAY)
gray_img = np.stack([gray_img]*3, axis=-1)
# บันทึกภาพ
color_path = f'training_data/sample_color_{i:03d}.jpg'
gray_path = f'training_data/sample_gray_{i:03d}.jpg'
cv2.imwrite(color_path, color_img)
cv2.imwrite(gray_path, gray_img)
sample_images.append((gray_path, color_path))
return sample_images
def adaptive_data_collection(self):
"""รวบรวมข้อมูลตามทรัพยากรที่มี"""
max_galleries = self.settings['max_galleries_per_cycle']
print(f"📥 รวบรวมข้อมูล: {max_galleries} galleries")
# สำหรับการทดสอบ ใช้ภาพตัวอย่าง
if self.settings['system_type'] == 'low_end':
print("💡 ระบบ Low-end: ใช้ภาพตัวอย่าง")
return self.get_sample_image_urls()
else:
# พยายามดาวน์โหลดข้อมูลจริง (optional)
try:
return self.download_limited_data(max_galleries)
except Exception as e:
print(f"⚠️ ดาวน์โหลดล้มเหลว, ใช้ภาพตัวอย่าง: {e}")
return self.get_sample_image_urls()
def download_limited_data(self, max_galleries):
"""ดาวน์โหลดข้อมูลจำนวนจำกัด"""
# ตัวอย่างการดาวน์โหลดจากแหล่งข้อมูลสาธารณะ
image_pairs = []
# ดาวน์โหลดภาพตัวอย่างจากแหล่งสาธารณะ (เช่น Unsplash)
try:
print("🌐 ดาวน์โหลดภาพตัวอย่างจากแหล่งสาธารณะ...")
# ตัวอย่าง URL ภาพสาธารณะ (สามารถเพิ่มได้)
sample_urls = [
"https://images.unsplash.com/photo-1541963463532-d68292c34b19",
"https://images.unsplash.com/photo-1551963831-b3b1ca40c98e",
"https://images.unsplash.com/photo-1551782450-a2132b4ba21d"
]
for i, url in enumerate(sample_urls[:3]): # จำกัดที่ 3 ภาพ
try:
response = self.scraper.get(url, timeout=10)
if response.status_code == 200:
# บันทึกภาพสี
color_path = f'training_data/web_color_{i:03d}.jpg'
with open(color_path, 'wb') as f:
f.write(response.content)
# สร้างภาพ grayscale
img = cv2.imread(color_path)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray_rgb = np.stack([gray]*3, axis=-1)
gray_path = f'training_data/web_gray_{i:03d}.jpg'
cv2.imwrite(gray_path, gray_rgb)
image_pairs.append((gray_path, color_path))
print(f"✅ ดาวน์โหลดภาพ {i+1} สำเร็จ")
except Exception as e:
print(f"❌ ดาวน์โหลดภาพ {i+1} ล้มเหลว: {e}")
except Exception as e:
print(f"❌ การดาวน์โหลดล้มเหลว: {e}")
# หากดาวน์โหลดไม่สำเร็จ ให้ใช้ภาพตัวอย่าง
if not image_pairs:
image_pairs = self.get_sample_image_urls()
return image_pairs
# ==================== Dataset and DataLoader ====================
class MangaDataset(Dataset):
def __init__(self, image_pairs, target_size=256, augment=True):
self.image_pairs = image_pairs
self.target_size = target_size
self.augment = augment
# Data augmentation
if augment:
self.transform = A.Compose([
A.HorizontalFlip(p=0.3),
A.RandomRotate90(p=0.2),
A.RandomBrightnessContrast(p=0.2),
A.HueSaturationValue(p=0.2),
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
ToTensorV2(),
])
else:
self.transform = A.Compose([
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
ToTensorV2(),
])
def __len__(self):
return len(self.image_pairs)
def __getitem__(self, idx):
try:
gray_path, color_path = self.image_pairs[idx]
# โหลดภาพ
gray_img = cv2.imread(gray_path)
color_img = cv2.imread(color_path)
if gray_img is None or color_img is None:
raise ValueError("Cannot load image")
# แปลงสี
gray_img = cv2.cvtColor(gray_img, cv2.COLOR_BGR2RGB)
color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)
# ปรับขนาด
gray_img = cv2.resize(gray_img, (self.target_size, self.target_size))
color_img = cv2.resize(color_img, (self.target_size, self.target_size))
# Data augmentation
if self.augment:
transformed = self.transform(image=color_img, mask=gray_img)
color_tensor = transformed['image']
gray_tensor = transformed['mask']
else:
color_tensor = self.transform(image=color_img)['image']
gray_tensor = self.transform(image=gray_img)['image']
return gray_tensor, color_tensor
except Exception as e:
print(f"❌ โหลดภาพล้มเหลว: {e}")
# ส่งคืน dummy data
dummy = torch.zeros(3, self.target_size, self.target_size)
return dummy, dummy
# ==================== Training System ====================
class AdaptiveTrainingSystem:
def __init__(self):
self.resource_manager = ResourceManager()
self.settings = self.resource_manager.get_optimized_settings()
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# เลือกโมเดลตามทรัพยากร
if self.settings['model_complexity'] == 'high':
self.model = AdvancedColorizationModel(base_channels=64, use_attention=True)
elif self.settings['model_complexity'] == 'medium':
self.model = AdvancedColorizationModel(base_channels=48, use_attention=False)
else:
self.model = SimpleColorizationModel(base_channels=32)
self.model = self.model.to(self.device)
# ตั้งค่า optimizer
self.setup_optimizer()
# Monitoring
self.performance_metrics = {
'loss_history': [],
'memory_usage': [],
'training_time': []
}
print("🎯 โมเดลที่ใช้:", self.model.__class__.__name__)
print("⚙️ การตั้งค่า:", self.settings)
def setup_optimizer(self):
"""ตั้งค่า optimizer ตามทรัพยากร"""
if self.settings['system_type'] == 'high_end':
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=0.001,
weight_decay=0.01
)
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer,
T_max=self.settings['training_epochs']
)
elif self.settings['system_type'] == 'medium_end':
self.optimizer = torch.optim.Adam(
self.model.parameters(),
lr=0.001
)
self.scheduler = torch.optim.lr_scheduler.StepLR(
self.optimizer,
step_size=5,
gamma=0.8
)
else: # low_end
self.optimizer = torch.optim.Adam(
self.model.parameters(),
lr=0.0005
)
self.scheduler = None
def monitor_resources(self):
"""ตรวจสอบทรัพยากรระหว่างการทำงาน"""
try:
memory_usage = psutil.virtual_memory().percent
self.performance_metrics['memory_usage'].append(memory_usage)
# ปรับการตั้งค่าตามทรัพยากรปัจจุบัน
if memory_usage > 85:
self.reduce_memory_usage()
# ตรวจสอบ GPU memory
if torch.cuda.is_available():
gpu_memory = torch.cuda.memory_allocated() / (1024**3)
if hasattr(self, 'last_gpu_memory') and gpu_memory > self.last_gpu_memory * 1.5:
self.reduce_batch_size()
self.last_gpu_memory = gpu_memory
except Exception as e:
print(f"⚠️ ตรวจสอบทรัพยากรล้มเหลว: {e}")
def reduce_memory_usage(self):
"""ลดการใช้ memory เมื่อจำเป็น"""
print("⚠️ หน่วยความจำใกล้เต็ม, ลดการใช้ทรัพยากร...")
# ลด batch size
if self.settings['batch_size'] > 1:
self.settings['batch_size'] = max(1, self.settings['batch_size'] // 2)
print(f"🔽 ลด batch size เป็น: {self.settings['batch_size']}")
# ล้าง cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
# เรียก garbage collector
import gc
gc.collect()
def reduce_batch_size(self):
"""ลด batch size สำหรับ GPU"""
current_bs = self.settings['batch_size']
if current_bs > 1:
self.settings['batch_size'] = current_bs // 2
print(f"🔽 ลด batch size เป็น: {self.settings['batch_size']}")
def create_dataloader(self, image_pairs):
"""สร้าง DataLoader"""
dataset = MangaDataset(
image_pairs,
target_size=self.settings['image_size'],
augment=True
)
return DataLoader(
dataset,
batch_size=self.settings['batch_size'],
shuffle=True,
num_workers=self.settings['dataloader_workers'],
pin_memory=torch.cuda.is_available()
)
def train_epoch(self, dataloader, epoch):
"""ฝึกหนึ่ง epoch"""
self.model.train()
total_loss = 0
criterion = nn.L1Loss() # MAE loss
start_time = time.time()
for batch_idx, (gray, color) in enumerate(dataloader):
# ตรวจสอบทรัพยากร
self.monitor_resources()
gray = gray.to(self.device)
color = color.to(self.device)
self.optimizer.zero_grad()
output = self.model(gray)
loss = criterion(output, color)
loss.backward()
# Gradient clipping สำหรับระบบ low-end
if self.settings['system_type'] == 'low_end':
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=1.0
)
self.optimizer.step()
total_loss += loss.item()
if batch_idx % 5 == 0:
print(f'📦 Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.6f}')
avg_loss = total_loss / len(dataloader)
epoch_time = time.time() - start_time
# อัปเดต scheduler
if self.scheduler:
self.scheduler.step()
self.performance_metrics['loss_history'].append(avg_loss)
self.performance_metrics['training_time'].append(epoch_time)
return avg_loss
def save_checkpoint(self, epoch, loss):
"""บันทึก checkpoint"""
try:
os.makedirs('models', exist_ok=True)
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
'loss': loss,
'settings': self.settings,
'performance_metrics': self.performance_metrics
}
checkpoint_path = f"models/checkpoint_epoch_{epoch:04d}.pth"
torch.save(checkpoint, checkpoint_path)
# บันทึกโมเดลที่ดีที่สุด
if not hasattr(self, 'best_loss') or loss < self.best_loss:
self.best_loss = loss
torch.save(self.model.state_dict(), "models/best_model.pth")
print(f"🏆 บันทึกโมเดลที่ดีที่สุด (Loss: {loss:.6f})")
print(f"💾 บันทึก checkpoint: {checkpoint_path}")
# ลบ checkpoint เก่า
self.cleanup_old_checkpoints()
except Exception as e:
print(f"❌ บันทึก checkpoint ล้มเหลว: {e}")
def cleanup_old_checkpoints(self):
"""ลบ checkpoint เก่า"""
try:
checkpoints = [f for f in os.listdir('models') if f.startswith('checkpoint_')]
if len(checkpoints) > 3: # เก็บไว้แค่ 3 ไฟล์ล่าสุด
checkpoints.sort()
for old_checkpoint in checkpoints[:-3]:
os.remove(os.path.join('models', old_checkpoint))
except Exception as e:
print(f"⚠️ ลบ checkpoint เก่าล้มเหลว: {e}")
def load_checkpoint(self, checkpoint_path):
"""โหลด checkpoint"""
try:
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if self.scheduler and checkpoint['scheduler_state_dict']:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.settings.update(checkpoint.get('settings', {}))
self.performance_metrics.update(checkpoint.get('performance_metrics', {}))
print(f"✅ โหลด checkpoint: {checkpoint_path}")
return checkpoint['epoch']
except Exception as e:
print(f"❌ โหลด checkpoint ล้มเหลว: {e}")
return 0
def colorize_image(self, image_path, output_path=None):
"""ลงสีภาพเดี่ยว"""
self.model.eval()
try:
# โหลดภาพ
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"ไม่สามารถโหลดภาพ: {image_path}")
# แปลงเป็น grayscale (เพื่อทดสอบ)
if len(image.shape) == 3:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
gray_rgb = np.stack([gray] * 3, axis=-1)
else:
gray_rgb = np.stack([image] * 3, axis=-1)
# ปรับขนาด
h, w = gray_rgb.shape[:2]
target_size = self.settings['image_size']
scale = target_size / max(h, w)
new_h, new_w = int(h * scale), int(w * scale)
gray_resized = cv2.resize(gray_rgb, (new_w, new_h))
# Preprocess
gray_tensor = torch.from_numpy(gray_resized).permute(2, 0, 1).unsqueeze(0)
gray_tensor = gray_tensor.float() / 255.0
gray_tensor = (gray_tensor - 0.5) / 0.5 # Normalize
# Prediction
with torch.no_grad():
output = self.model(gray_tensor.to(self.device))
output = output.squeeze(0).cpu().numpy()
output = np.transpose(output, (1, 2, 0))
output = (output * 0.5 + 0.5) * 255 # Denormalize
output = output.astype(np.uint8)
# ปรับขนาดกลับ
output = cv2.resize(output, (w, h))
if output_path:
cv2.imwrite(output_path, cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
print(f"✅ บันทึกภาพลงสี: {output_path}")
return output
except Exception as e:
print(f"❌ ลงสีภาพล้มเหลว: {e}")
return None
# ==================== Main System ====================
class AutoMangaColorizationSystem:
def __init__(self):
print("🚀 เริ่มต้นระบบลงสีมังงะอัตโนมัติ...")
print("=" * 50)
self.training_system = AdaptiveTrainingSystem()
self.data_collector = SmartDataCollector(self.training_system.settings)
self.cycle_count = 0
# พยายามโหลด checkpoint ล่าสุด
self.load_latest_checkpoint()
def load_latest_checkpoint(self):
"""โหลด checkpoint ล่าสุด"""
try:
if os.path.exists('models'):
checkpoints = [f for f in os.listdir('models') if f.startswith('checkpoint_')]
if checkpoints:
latest_checkpoint = max(checkpoints)
checkpoint_path = os.path.join('models', latest_checkpoint)
self.start_epoch = self.training_system.load_checkpoint(checkpoint_path) + 1
print(f"🔄 ดำเนินการต่อจาก epoch {self.start_epoch}")
return
except Exception as e:
print(f"⚠️ โหลด checkpoint ล้มเหลว: {e}")
self.start_epoch = 0
def run_continuous_learning(self, total_cycles=100):
"""รันระบบเรียนรู้ต่อเนื่อง"""
print(f"🎯 เริ่มการเรียนรู้ {total_cycles} วงจร...")
for cycle in range(self.start_epoch, total_cycles):
self.cycle_count = cycle
print(f"\n{'='*60}")
print(f"🔄 วงจรที่ {cycle + 1}/{total_cycles}")
print(f"{'='*60}")
try:
# 1. รวบรวมข้อมูล
print("📥 กำลังรวบรวมข้อมูล...")
image_pairs = self.data_collector.adaptive_data_collection()
if not image_pairs:
print("⚠️ ไม่พบข้อมูล, ข้ามการฝึกในวงจรนี้")
continue
print(f"📊 ได้ข้อมูล {len(image_pairs)} คู่ภาพ")
# 2. สร้าง DataLoader
dataloader = self.training_system.create_dataloader(image_pairs)
# 3. ฝึกโมเดล
print("🎯 เริ่มการฝึก...")
avg_loss = self.training_system.train_epoch(dataloader, cycle)
print(f"✅ วงจร {cycle + 1} เสร็จสิ้น, Loss: {avg_loss:.6f}")
# 4. บันทึกผลลัพธ์
if cycle % 3 == 0: # บันทึกทุก 3 วงจร
self.training_system.save_checkpoint(cycle, avg_loss)
# 5. ทดสอบโมเดล
if cycle % 5 == 0 and image_pairs:
self.test_current_model(cycle, image_pairs[0][0])
# 6. จัดการทรัพยากร
self.cleanup_resources()
# 7. พักระหว่างวงจร
self.adaptive_sleep(cycle)
except Exception as e:
print(f"❌ วงจร {cycle} ล้มเหลว: {e}")
self.system_recovery()
def test_current_model(self, cycle, test_image_path):
"""ทดสอบโมเดลปัจจุบัน"""
try:
if os.path.exists(test_image_path):
output_path = f"results/test_cycle_{cycle:04d}.jpg"
os.makedirs('results', exist_ok=True)
result = self.training_system.colorize_image(test_image_path, output_path)
if result is not None:
print(f"🧪 ทดสอบโมเดลวงจร {cycle} สำเร็จ")
else:
print(f"⚠️ การทดสอบล้มเหลว")
else:
print("⚠️ ไม่พบไฟล์ทดสอบ")
except Exception as e:
print(f"⚠️ การทดสอบล้มเหลว: {e}")
def cleanup_resources(self):
"""ทำความสะอาดทรัพยากร"""
try:
# ล้าง cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
# ลบไฟล์ชั่วคราว
temp_files = [f for f in os.listdir('temp_downloads') if f.endswith('.tmp')]
for temp_file in temp_files:
try:
os.remove(os.path.join('temp_downloads', temp_file))
except:
pass
# เรียก garbage collector
import gc
gc.collect()
except Exception as e:
print(f"⚠️ ทำความสะอาดทรัพยากรล้มเหลว: {e}")
def adaptive_sleep(self, cycle):
"""พักระหว่างวงจรตามทรัพยากร"""
system_type = self.training_system.settings['system_type']
if system_type == 'high_end':
sleep_time = 120 # 2 นาที
elif system_type == 'medium_end':
sleep_time = 180 # 3 นาที
else:
sleep_time = 300 # 5 นาที
print(f"⏳ พัก {sleep_time} วินาที...")
# พักแบบแบ่งช่วงเพื่อตรวจสอบทรัพยากร
for i in range(sleep_time // 30):
time.sleep(30)
self.training_system.monitor_resources()
def system_recovery(self):
"""กู้คืนระบบเมื่อเกิดข้อผิดพลาด"""
print("🔄 พยายามกู้คืนระบบ...")
try:
# ล้าง memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
# รีสตาร์ทระบบย่อย
self.training_system = AdaptiveTrainingSystem()
print("✅ กู้คืนระบบสำเร็จ")
except Exception as e:
print(f"❌ กู้คืนระบบล้มเหลว: {e}")
def colorize_single_image(self, image_path, output_path=None):
"""ลงสีภาพเดี่ยว"""
if output_path is None:
output_path = f"colored_{os.path.basename(image_path)}"
print(f"🎨 กำลังลงสีภาพ: {image_path}")
result = self.training_system.colorize_image(image_path, output_path)
if result is not None:
print(f"✅ ลงสีภาพสำเร็จ: {output_path}")
return True
else:
print("❌ ลงสีภาพล้มเหลว")
return False
# ==================== Main Execution ====================
def main():
"""ฟังก์ชันหลัก"""
print("🎨 Manga Colorization System")
print("สร้างโดย: Auto AI System")
print("=" * 50)
try:
# สร้างระบบ
system = AutoMangaColorizationSystem()
# เริ่มการเรียนรู้
print("\n🎯 เริ่มกระบวนการเรียนรู้...")
system.run_continuous_learning(total_cycles=50)
except KeyboardInterrupt:
print("\n⏹️ หยุดระบบโดยผู้ใช้")
except Exception as e:
print(f"\n💥 ระบบหยุดทำงาน: {e}")
finally:
print("\n🧹 ทำความสะอาดทรัพยากร...")
print("✅ โปรแกรมสิ้นสุดการทำงาน")
def quick_test():
"""ทดสอบระบบอย่างรวดเร็ว"""
print("🧪 โหมดทดสอบอย่างรวดเร็ว...")
try:
system = AutoMangaColorizationSystem()
# ทดสอบด้วยวงจรเดียว
print("🔬 ทดสอบ 1 วงจร...")
image_pairs = system.data_collector.adaptive_data_collection()
if image_pairs:
dataloader = system.training_system.create_dataloader(image_pairs)
loss = system.training_system.train_epoch(dataloader, 0)
print(f"📊 Loss: {loss:.6f}")
# ทดสอบลงสี
if image_pairs:
test_image = image_pairs[0][0]
system.test_current_model(0, test_image)
print("✅ การทดสอบสำเร็จ")
except Exception as e:
print(f"❌ การทดสอบล้มเหลว: {e}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Manga Colorization System')
parser.add_argument('--test', action='store_true', help='โหมดทดสอบอย่างรวดเร็ว')
parser.add_argument('--colorize', type=str, help='ลงสีภาพเดี่ยว')
parser.add_argument('--output', type=str, help='ไฟล์ผลลัพธ์สำหรับโหมด colorize')
args = parser.parse_args()
if args.test:
quick_test()
elif args.colorize:
system = AutoMangaColorizationSystem()
system.colorize_single_image(args.colorize, args.output)
else:
main()