|
|
|
|
|
import gc
import json
import os
import random
import threading
import time
import warnings
from pathlib import Path
from queue import Queue

import albumentations as A
import cloudscraper
import cv2
import GPUtil
import numpy as np
import psutil
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from bs4 import BeautifulSoup
from PIL import Image
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings('ignore')
|
|
|
|
|
|
|
|
class ResourceManager:
    """Probe the host hardware once and map it to a training preset.

    On construction the machine is profiled (GPU, CPU, RAM, free disk),
    classified as ``'high_end'``, ``'medium_end'`` or ``'low_end'``, and the
    profile is written to ``resource_profile.json`` in the working directory.
    """

    # One preset per hardware tier. Callers receive a *copy*, because the
    # training code mutates its settings dict (e.g. halves ``batch_size``).
    _PRESETS = {
        'high_end': {
            'batch_size': 12,
            'image_size': 512,
            'dataloader_workers': 6,
            'max_galleries_per_cycle': 15,
            'training_epochs': 25,
            'use_gan': True,
            'model_complexity': 'high',
            'max_cache_size_gb': 30,
            'parallel_download': True,
            'max_images_per_gallery': 50,
        },
        'medium_end': {
            'batch_size': 6,
            'image_size': 384,
            'dataloader_workers': 3,
            'max_galleries_per_cycle': 8,
            'training_epochs': 15,
            'use_gan': False,
            'model_complexity': 'medium',
            'max_cache_size_gb': 15,
            'parallel_download': False,
            'max_images_per_gallery': 30,
        },
        'low_end': {
            'batch_size': 2,
            'image_size': 256,
            'dataloader_workers': 1,
            'max_galleries_per_cycle': 3,
            'training_epochs': 8,
            'use_gan': False,
            'model_complexity': 'low',
            'max_cache_size_gb': 5,
            'parallel_download': False,
            'max_images_per_gallery': 20,
        },
    }

    def __init__(self):
        self.resource_profile = self.analyze_system_resources()
        self.save_resource_profile()

    def analyze_system_resources(self):
        """Inspect the machine and return a profile dict.

        Returns:
            dict: hardware facts plus a ``'system_type'`` tier. If probing
            fails for any reason, a conservative ``'low_end'`` profile with a
            reduced key set is returned instead of raising.
        """
        try:
            gpu_available = torch.cuda.is_available()
            memory = psutil.virtual_memory()
            gib = 1024 ** 3

            profile = {
                'gpu_available': gpu_available,
                'gpu_count': torch.cuda.device_count() if gpu_available else 0,
                'gpu_memory': [],
                'cpu_cores': psutil.cpu_count(logical=False),
                'cpu_threads': psutil.cpu_count(logical=True),
                'ram_total_gb': memory.total / gib,
                'ram_available_gb': memory.available / gib,
                'disk_free_gb': psutil.disk_usage('.').free / gib,
            }

            if gpu_available:
                try:
                    profile['gpu_memory'] = [
                        {
                            'name': gpu.name,
                            'memory_total': gpu.memoryTotal,
                            'memory_free': gpu.memoryFree,
                        }
                        for gpu in GPUtil.getGPUs()
                    ]
                except Exception:
                    # Per-GPU details are best-effort; a failing GPUtil must
                    # not abort profiling (but Ctrl-C still propagates).
                    profile['gpu_memory'] = []

            # Tiering: GPU + >=16 GiB RAM is high, GPU *or* >=8 GiB is medium.
            if gpu_available and profile['ram_total_gb'] >= 16:
                profile['system_type'] = 'high_end'
            elif gpu_available or profile['ram_total_gb'] >= 8:
                profile['system_type'] = 'medium_end'
            else:
                profile['system_type'] = 'low_end'

            print(f"🔍 ตรวจสอบระบบ: {profile['system_type'].upper()}")
            print(f"💻 CPU: {profile['cpu_cores']} cores, {profile['cpu_threads']} threads")
            print(f"🧠 RAM: {profile['ram_total_gb']:.1f}GB (ใช้ได้ {profile['ram_available_gb']:.1f}GB)")
            print(f"🎮 GPU: {profile['gpu_available']} ({profile['gpu_count']} devices)")
            for gpu in profile['gpu_memory']:
                print(f" - {gpu['name']}: {gpu['memory_total']}MB")

            return profile

        except Exception as e:
            print(f"❌ ตรวจสอบทรัพยากรล้มเหลว: {e}")
            return {
                'system_type': 'low_end',
                'cpu_cores': 2,
                'ram_total_gb': 4,
                'gpu_available': False,
                'gpu_count': 0
            }

    def save_resource_profile(self):
        """Write the profile to ``resource_profile.json``; failures are only logged."""
        try:
            with open('resource_profile.json', 'w', encoding='utf-8') as f:
                json.dump(self.resource_profile, f, indent=2, ensure_ascii=False)
        except Exception as e:
            print(f"❌ บันทึกโปรไฟล์ล้มเหลว: {e}")

    def get_optimized_settings(self):
        """Return a fresh settings dict for the detected hardware tier.

        The result includes a ``'system_type'`` key, which the optimizer
        setup, the data collector and the inter-cycle sleep all index; without
        it they raise ``KeyError``. An unknown or missing tier is treated as
        ``'low_end'``.

        Returns:
            dict: a new dict on every call, safe for the caller to mutate.
        """
        tier = self.resource_profile.get('system_type', 'low_end')
        if tier not in self._PRESETS:
            tier = 'low_end'
        settings = dict(self._PRESETS[tier])
        settings['system_type'] = tier
        return settings
|
|
|
|
|
|
|
|
class SimpleColorizationModel(nn.Module):
    """Small U-Net-style network mapping a 3-channel image to a 3-channel image.

    Three pooled encoder stages feed a bottleneck, followed by three
    transposed-convolution stages with two skip connections.

    Args:
        base_channels: width of the first encoder stage; later stages use
            2x, 4x and 8x this value.

    Input height and width must be divisible by 8 (three 2x poolings),
    otherwise the skip concatenations will not line up. The output has the
    same spatial size as the input and lies in [-1, 1] (tanh).
    """

    def __init__(self, base_channels=32):
        super().__init__()
        c = base_channels

        # Encoder: every stage halves the spatial resolution.
        self.enc1 = self._encoder_stage(3, c)          # -> c,  H/2
        self.enc2 = self._encoder_stage(c, c * 2)      # -> 2c, H/4
        self.enc3 = self._encoder_stage(c * 2, c * 4)  # -> 4c, H/8

        self.bottleneck = nn.Sequential(
            nn.Conv2d(c * 4, c * 8, 3, padding=1),
            nn.BatchNorm2d(c * 8),
            nn.ReLU(inplace=True),
            nn.Conv2d(c * 8, c * 8, 3, padding=1),
            nn.BatchNorm2d(c * 8),
            nn.ReLU(inplace=True),
        )

        # Decoder: input widths account for the concatenated skip tensors
        # (previously they did not, so forward() always failed in torch.cat's
        # consumer with a channel mismatch).
        self.dec1 = self._decoder_stage(c * 8, c * 4)          # -> 4c, H/4
        self.dec2 = self._decoder_stage(c * 4 + c * 2, c * 2)  # cat(e2) -> 2c, H/2
        self.dec3 = self._decoder_stage(c * 2 + c, c)          # cat(e1) -> c,  H

        self.output = nn.Conv2d(c, 3, 3, padding=1)

    @staticmethod
    def _encoder_stage(in_ch, out_ch):
        """Two conv-BN-ReLU layers followed by 2x max pooling."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    @staticmethod
    def _decoder_stage(in_ch, out_ch):
        """2x transposed-conv upsampling followed by BN and ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (N, 3, H, W) with H and W divisible by 8.

        Returns:
            Tensor of shape (N, 3, H, W) with values in [-1, 1].
        """
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)

        b = self.bottleneck(e3)

        d1 = torch.cat([self.dec1(b), e2], dim=1)
        d2 = torch.cat([self.dec2(d1), e1], dim=1)
        d3 = self.dec3(d2)

        # tanh rather than sigmoid: the dataset normalises targets with
        # mean=0.5/std=0.5 (range [-1, 1]) and inference de-normalises with
        # x * 0.5 + 0.5, so the output range has to be [-1, 1] as well.
        return torch.tanh(self.output(d3))
|
|
|
|
|
class AdvancedColorizationModel(nn.Module):
    """Deeper U-Net-style colorizer with optional self-attention and a residual path.

    Four pooled encoder stages, an optional multi-head self-attention over the
    deepest feature map, four upsampling decoder stages with three skip
    connections, and a 1x1 convolution of the input added to the prediction.

    Args:
        base_channels: width of the first encoder stage. When
            ``use_attention`` is true, ``base_channels * 4`` and
            ``base_channels * 8`` must be divisible by 8 (the head count).
        use_attention: whether to apply self-attention at the bottleneck.

    Input height and width must be divisible by 16 (four 2x poolings). The
    output has the input's spatial size and lies in [-1, 1] (tanh).
    """

    def __init__(self, base_channels=64, use_attention=True):
        super().__init__()

        self.use_attention = use_attention
        c = base_channels

        # Encoder: each block halves the resolution.
        self.enc1 = self._make_encoder_block(3, c)          # -> c,  H/2
        self.enc2 = self._make_encoder_block(c, c * 2)      # -> 2c, H/4
        self.enc3 = self._make_encoder_block(c * 2, c * 4)  # -> 4c, H/8
        self.enc4 = self._make_encoder_block(c * 4, c * 8)  # -> 8c, H/16

        if use_attention:
            self.attention1 = nn.MultiheadAttention(c * 8, 8)
            # NOTE: attention2 is never called in forward(); it is kept so the
            # set of parameter names in saved state dicts does not change.
            self.attention2 = nn.MultiheadAttention(c * 4, 8)

        # Decoder: input widths are the previous output plus the skip tensor.
        self.dec1 = self._make_decoder_block(c * 8, c * 4)   # -> 4c, H/8
        self.dec2 = self._make_decoder_block(c * 8, c * 2)   # cat(e3) -> 2c, H/4
        self.dec3 = self._make_decoder_block(c * 4, c)       # cat(e2) -> c,  H/2
        self.dec4 = self._make_decoder_block(c * 2, c // 2)  # cat(e1) -> c/2, H

        # The head consumes dec4's output, which has c // 2 channels. It used
        # to declare c input channels, so forward() always raised.
        self.output = nn.Sequential(
            nn.Conv2d(c // 2, c // 2, 3, padding=1),
            nn.BatchNorm2d(c // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(c // 2, 3, 3, padding=1)
        )

        self.residual_conv = nn.Conv2d(3, 3, 1)

    def _make_encoder_block(self, in_ch, out_ch):
        """Two conv-BN-ReLU layers followed by 2x max pooling."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    def _make_decoder_block(self, in_ch, out_ch):
        """2x transposed-conv upsampling followed by one conv-BN-ReLU layer."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (N, 3, H, W) with H and W divisible by 16.

        Returns:
            Tensor of shape (N, 3, H, W) with values in [-1, 1].
        """
        x_input = x

        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)

        if self.use_attention:
            b, c, h, w = e4.shape
            # MultiheadAttention (batch_first=False) expects (seq, batch, embed).
            tokens = e4.flatten(2).permute(2, 0, 1)
            attended, _ = self.attention1(tokens, tokens, tokens)
            # reshape (not view): the permuted tensor is not contiguous.
            e4 = attended.permute(1, 2, 0).reshape(b, c, h, w)

        d1 = torch.cat([self.dec1(e4), e3], dim=1)
        d2 = torch.cat([self.dec2(d1), e2], dim=1)
        d3 = torch.cat([self.dec3(d2), e1], dim=1)
        d4 = self.dec4(d3)

        output = self.output(d4) + self.residual_conv(x_input)

        # tanh rather than sigmoid: training targets are normalised to
        # [-1, 1] (mean=0.5/std=0.5) and inference undoes exactly that.
        return torch.tanh(output)
|
|
|
|
|
|
|
|
class SmartDataCollector:
    """Produce (grayscale_path, color_path) training pairs on disk.

    Depending on the configured system tier, pairs are either synthesised
    locally (random noise images) or downloaded from a short fixed URL list,
    with the synthetic images as a fallback.

    Args:
        settings: dict that must contain ``'max_galleries_per_cycle'``;
            ``'system_type'`` is read if present.
    """

    def __init__(self, settings):
        self.settings = settings
        self.scraper = cloudscraper.create_scraper()
        self.downloaded_count = 0
        self.processed_galleries = set()
        self.load_processed_list()

        os.makedirs('training_data', exist_ok=True)
        os.makedirs('temp_downloads', exist_ok=True)

    def load_processed_list(self):
        """Load the processed-gallery set from ``processed_galleries.json``.

        A missing file leaves the current set untouched; an unreadable or
        malformed file resets it to an empty set.
        """
        try:
            if os.path.exists('processed_galleries.json'):
                with open('processed_galleries.json', 'r') as f:
                    self.processed_galleries = set(json.load(f))
        except (OSError, ValueError, TypeError):
            # ValueError covers invalid JSON; TypeError covers JSON whose
            # top-level value is not an iterable of hashable items.
            self.processed_galleries = set()

    def save_processed_list(self):
        """Write the processed-gallery set to ``processed_galleries.json``."""
        try:
            with open('processed_galleries.json', 'w') as f:
                json.dump(list(self.processed_galleries), f)
        except Exception as e:
            print(f"❌ บันทึกรายการล้มเหลว: {e}")

    def get_sample_image_urls(self):
        """Generate 20 synthetic image pairs and return their paths.

        Despite the name, no URLs are involved: random 256x256 colour images
        and their grayscale versions are written under ``training_data/``.

        Returns:
            list[tuple[str, str]]: ``(gray_path, color_path)`` pairs.
        """
        print("📝 ใช้ภาพตัวอย่างสำหรับทดสอบ...")

        # cv2.imwrite returns False instead of raising if the folder is missing.
        os.makedirs('training_data', exist_ok=True)

        sample_images = []
        for i in range(20):
            # randint's upper bound is exclusive; 256 makes 255 reachable.
            color_img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
            # cv2.imwrite interprets the array as BGR, so derive the gray
            # image under the same channel convention.
            gray_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2GRAY)
            gray_img = np.stack([gray_img] * 3, axis=-1)

            color_path = f'training_data/sample_color_{i:03d}.jpg'
            gray_path = f'training_data/sample_gray_{i:03d}.jpg'

            cv2.imwrite(color_path, color_img)
            cv2.imwrite(gray_path, gray_img)

            sample_images.append((gray_path, color_path))

        return sample_images

    def adaptive_data_collection(self):
        """Collect training pairs in the way the system tier allows.

        Low-end systems (and settings without a ``'system_type'`` key) use the
        synthetic samples; other tiers try to download and fall back to the
        synthetic samples on any error.

        Returns:
            list[tuple[str, str]]: ``(gray_path, color_path)`` pairs.
        """
        max_galleries = self.settings['max_galleries_per_cycle']
        print(f"📥 รวบรวมข้อมูล: {max_galleries} galleries")

        # .get(): settings built without 'system_type' must not crash here.
        if self.settings.get('system_type', 'low_end') == 'low_end':
            print("💡 ระบบ Low-end: ใช้ภาพตัวอย่าง")
            return self.get_sample_image_urls()

        try:
            return self.download_limited_data(max_galleries)
        except Exception as e:
            print(f"⚠️ ดาวน์โหลดล้มเหลว, ใช้ภาพตัวอย่าง: {e}")
            return self.get_sample_image_urls()

    def download_limited_data(self, max_galleries):
        """Download up to ``max_galleries`` images from a fixed URL list.

        Each image is stored under ``training_data/`` together with a
        grayscale copy. Per-URL failures are logged and skipped. If nothing
        could be downloaded, the synthetic samples are returned instead.

        Args:
            max_galleries: upper bound on the number of URLs fetched.

        Returns:
            list[tuple[str, str]]: ``(gray_path, color_path)`` pairs.
        """
        image_pairs = []

        print("🌐 ดาวน์โหลดภาพตัวอย่างจากแหล่งสาธารณะ...")

        sample_urls = [
            "https://images.unsplash.com/photo-1541963463532-d68292c34b19",
            "https://images.unsplash.com/photo-1551963831-b3b1ca40c98e",
            "https://images.unsplash.com/photo-1551782450-a2132b4ba21d"
        ]

        limit = max(0, max_galleries)
        for i, url in enumerate(sample_urls[:limit]):
            try:
                response = self.scraper.get(url, timeout=10)
                if response.status_code != 200:
                    raise ValueError(f"HTTP {response.status_code}")

                color_path = f'training_data/web_color_{i:03d}.jpg'
                with open(color_path, 'wb') as f:
                    f.write(response.content)

                img = cv2.imread(color_path)
                if img is None:
                    # The payload was not a decodable image (e.g. an HTML
                    # error page); do not leave it in the training folder.
                    os.remove(color_path)
                    raise ValueError("downloaded file is not a readable image")

                gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                gray_rgb = np.stack([gray] * 3, axis=-1)
                gray_path = f'training_data/web_gray_{i:03d}.jpg'
                cv2.imwrite(gray_path, gray_rgb)

                image_pairs.append((gray_path, color_path))
                print(f"✅ ดาวน์โหลดภาพ {i+1} สำเร็จ")

            except Exception as e:
                print(f"❌ ดาวน์โหลดภาพ {i+1} ล้มเหลว: {e}")

        if not image_pairs:
            image_pairs = self.get_sample_image_urls()

        return image_pairs
|
|
|
|
|
|
|
|
class MangaDataset(Dataset):
    """Dataset of (grayscale, colour) image pairs loaded from disk.

    Args:
        image_pairs: sequence of ``(gray_path, color_path)`` tuples.
        target_size: both images are resized to ``target_size x target_size``.
        augment: whether to apply random augmentation.

    Each item is ``(gray_tensor, color_tensor)``; both are float tensors of
    shape (3, target_size, target_size) normalised with mean 0.5 / std 0.5.
    """

    def __init__(self, image_pairs, target_size=256, augment=True):
        self.image_pairs = image_pairs
        self.target_size = target_size
        self.augment = augment

        # Registering the gray image as an additional *image* target makes
        # albumentations apply the same sampled transform to both pictures.
        # Passing it as ``mask`` (as before) skipped Normalize and left it as
        # an HWC uint8 tensor, which the model cannot consume.
        paired = {'gray': 'image'}

        if augment:
            # Flips/rotations must match between input and target;
            # brightness/contrast is shared so the pair stays consistent.
            self.paired_transform = A.Compose([
                A.HorizontalFlip(p=0.3),
                A.RandomRotate90(p=0.2),
                A.RandomBrightnessContrast(p=0.2),
            ], additional_targets=paired)
            # Hue/saturation jitter only makes sense on the colour target.
            self.color_transform = A.Compose([
                A.HueSaturationValue(p=0.2),
            ])
        else:
            self.paired_transform = None
            self.color_transform = None

        # Final conversion, always applied to both images.
        self.transform = A.Compose([
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2(),
        ], additional_targets=paired)

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Load, resize, optionally augment and tensorise one pair.

        Any failure (missing or unreadable file, bad index, transform error)
        is logged and replaced by a pair of all-zero tensors, so a single bad
        sample does not stop a training run.
        """
        try:
            gray_path, color_path = self.image_pairs[idx]

            gray_img = cv2.imread(gray_path)
            color_img = cv2.imread(color_path)

            if gray_img is None or color_img is None:
                raise ValueError("Cannot load image")

            # OpenCV loads BGR; the rest of the pipeline works in RGB.
            gray_img = cv2.cvtColor(gray_img, cv2.COLOR_BGR2RGB)
            color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)

            size = (self.target_size, self.target_size)
            gray_img = cv2.resize(gray_img, size)
            color_img = cv2.resize(color_img, size)

            if self.augment:
                both = self.paired_transform(image=color_img, gray=gray_img)
                color_img, gray_img = both['image'], both['gray']
                color_img = self.color_transform(image=color_img)['image']

            tensors = self.transform(image=color_img, gray=gray_img)
            return tensors['gray'], tensors['image']

        except Exception as e:
            print(f"❌ โหลดภาพล้มเหลว: {e}")
            shape = (3, self.target_size, self.target_size)
            return torch.zeros(shape), torch.zeros(shape)
|
|
|
|
|
|
|
|
class AdaptiveTrainingSystem:
    """Own the model, optimizer and training loop, sized to the host machine.

    The constructor profiles the machine through ``ResourceManager``, picks a
    model by ``settings['model_complexity']``, moves it to CUDA when
    available, and configures an optimizer and scheduler by system tier.
    """

    def __init__(self):
        self.resource_manager = ResourceManager()
        self.settings = self.resource_manager.get_optimized_settings()
        # setup_optimizer() and train_epoch() index 'system_type'; take it
        # from the hardware profile if the settings dict does not carry it.
        self.settings.setdefault(
            'system_type',
            self.resource_manager.resource_profile.get('system_type', 'low_end'),
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        complexity = self.settings['model_complexity']
        if complexity == 'high':
            self.model = AdvancedColorizationModel(base_channels=64, use_attention=True)
        elif complexity == 'medium':
            self.model = AdvancedColorizationModel(base_channels=48, use_attention=False)
        else:
            self.model = SimpleColorizationModel(base_channels=32)

        self.model = self.model.to(self.device)

        self.setup_optimizer()

        self.performance_metrics = {
            'loss_history': [],
            'memory_usage': [],
            'training_time': []
        }

        print("🎯 โมเดลที่ใช้:", self.model.__class__.__name__)
        print("⚙️ การตั้งค่า:", self.settings)

    def setup_optimizer(self):
        """Create ``self.optimizer`` and ``self.scheduler`` for the system tier.

        high_end: AdamW + cosine annealing; medium_end: Adam + step decay;
        anything else: Adam at a lower learning rate and no scheduler.
        """
        system_type = self.settings['system_type']
        params = self.model.parameters()

        if system_type == 'high_end':
            self.optimizer = torch.optim.AdamW(params, lr=0.001, weight_decay=0.01)
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=self.settings['training_epochs']
            )
        elif system_type == 'medium_end':
            self.optimizer = torch.optim.Adam(params, lr=0.001)
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=5,
                gamma=0.8
            )
        else:
            self.optimizer = torch.optim.Adam(params, lr=0.0005)
            self.scheduler = None

    def monitor_resources(self):
        """Record RAM usage and shrink the batch size under memory pressure.

        RAM above 85% triggers ``reduce_memory_usage``. On CUDA, allocated
        memory growing by more than 1.5x since the previous call triggers
        ``reduce_batch_size``. Errors are logged, never raised.
        """
        try:
            memory_usage = psutil.virtual_memory().percent
            self.performance_metrics['memory_usage'].append(memory_usage)

            if memory_usage > 85:
                self.reduce_memory_usage()

            if torch.cuda.is_available():
                gpu_memory = torch.cuda.memory_allocated() / (1024**3)
                previous = getattr(self, 'last_gpu_memory', None)
                if previous is not None and gpu_memory > previous * 1.5:
                    self.reduce_batch_size()
                self.last_gpu_memory = gpu_memory

        except Exception as e:
            print(f"⚠️ ตรวจสอบทรัพยากรล้มเหลว: {e}")

    def _halve_batch_size(self):
        """Halve ``settings['batch_size']`` (never below 1) and log the new value."""
        current = self.settings['batch_size']
        if current > 1:
            self.settings['batch_size'] = max(1, current // 2)
            print(f"🔽 ลด batch size เป็น: {self.settings['batch_size']}")

    def reduce_memory_usage(self):
        """Halve the batch size and release cached CUDA and Python memory.

        The new batch size only affects DataLoaders created afterwards.
        """
        print("⚠️ หน่วยความจำใกล้เต็ม, ลดการใช้ทรัพยากร...")

        self._halve_batch_size()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        gc.collect()

    def reduce_batch_size(self):
        """Halve the batch size used for subsequently created DataLoaders."""
        self._halve_batch_size()

    def create_dataloader(self, image_pairs):
        """Wrap ``image_pairs`` in an augmenting, shuffling DataLoader."""
        dataset = MangaDataset(
            image_pairs,
            target_size=self.settings['image_size'],
            augment=True
        )

        return DataLoader(
            dataset,
            batch_size=self.settings['batch_size'],
            shuffle=True,
            num_workers=self.settings['dataloader_workers'],
            pin_memory=torch.cuda.is_available()
        )

    def train_epoch(self, dataloader, epoch):
        """Train for one pass over ``dataloader`` with an L1 loss.

        Args:
            dataloader: iterable of ``(gray, color)`` tensor batches.
            epoch: epoch index, used only for logging.

        Returns:
            float: mean batch loss, or 0.0 if the loader yielded no batches
            (previously a ZeroDivisionError).
        """
        self.model.train()
        criterion = nn.L1Loss()
        clip_gradients = self.settings['system_type'] == 'low_end'

        total_loss = 0.0
        num_batches = 0
        start_time = time.time()

        for batch_idx, (gray, color) in enumerate(dataloader):
            self.monitor_resources()

            gray = gray.to(self.device)
            color = color.to(self.device)

            self.optimizer.zero_grad()
            loss = criterion(self.model(gray), color)
            loss.backward()

            if clip_gradients:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            if batch_idx % 5 == 0:
                print(f'📦 Epoch: {epoch}, Batch: {batch_idx}, Loss: {batch_loss:.6f}')

        avg_loss = total_loss / num_batches if num_batches else 0.0
        epoch_time = time.time() - start_time

        if self.scheduler:
            self.scheduler.step()

        self.performance_metrics['loss_history'].append(avg_loss)
        self.performance_metrics['training_time'].append(epoch_time)

        return avg_loss

    def save_checkpoint(self, epoch, loss):
        """Save a full checkpoint under ``models/`` and track the best model.

        Also writes ``models/best_model.pth`` when ``loss`` beats the best
        loss seen by this instance, and prunes old checkpoints. Errors are
        logged, never raised.
        """
        try:
            os.makedirs('models', exist_ok=True)

            checkpoint = {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
                'loss': loss,
                'settings': self.settings,
                'performance_metrics': self.performance_metrics
            }

            checkpoint_path = f"models/checkpoint_epoch_{epoch:04d}.pth"
            torch.save(checkpoint, checkpoint_path)

            if not hasattr(self, 'best_loss') or loss < self.best_loss:
                self.best_loss = loss
                torch.save(self.model.state_dict(), "models/best_model.pth")
                print(f"🏆 บันทึกโมเดลที่ดีที่สุด (Loss: {loss:.6f})")

            print(f"💾 บันทึก checkpoint: {checkpoint_path}")

            self.cleanup_old_checkpoints()

        except Exception as e:
            print(f"❌ บันทึก checkpoint ล้มเหลว: {e}")

    def cleanup_old_checkpoints(self):
        """Keep only the three checkpoints that sort last by file name."""
        try:
            checkpoints = sorted(
                f for f in os.listdir('models') if f.startswith('checkpoint_')
            )
            # Zero-padded epoch numbers make name order equal epoch order.
            for old_checkpoint in checkpoints[:-3]:
                os.remove(os.path.join('models', old_checkpoint))
        except Exception as e:
            print(f"⚠️ ลบ checkpoint เก่าล้มเหลว: {e}")

    def load_checkpoint(self, checkpoint_path):
        """Restore model, optimizer, scheduler, settings and metrics.

        Returns:
            int: the epoch stored in the checkpoint, or 0 if loading failed.
            NOTE(review): a failure is indistinguishable from a checkpoint
            saved at epoch 0 — confirm callers are fine with that.
        """
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            if self.scheduler and checkpoint['scheduler_state_dict']:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            self.settings.update(checkpoint.get('settings', {}))
            self.performance_metrics.update(checkpoint.get('performance_metrics', {}))

            print(f"✅ โหลด checkpoint: {checkpoint_path}")
            return checkpoint['epoch']

        except Exception as e:
            print(f"❌ โหลด checkpoint ล้มเหลว: {e}")
            return 0

    @staticmethod
    def _inference_size(height, width, target_size, multiple=16):
        """Scale (height, width) so the long side is about ``target_size``.

        Both sides are snapped to a positive multiple of ``multiple``. The
        deepest model in this file pools four times (stride 16); sizes that
        are not multiples of the stride make the skip connections mismatch.
        """
        scale = target_size / max(height, width)

        def snap(side):
            return max(1, int(round(side * scale / multiple))) * multiple

        return snap(height), snap(width)

    def colorize_image(self, image_path, output_path=None):
        """Colorize one image file.

        Args:
            image_path: path of the image to read.
            output_path: if given, the result is also written there.

        Returns:
            numpy.ndarray | None: RGB uint8 array at the original resolution,
            or ``None`` if anything failed (the error is logged).
        """
        self.model.eval()

        try:
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"ไม่สามารถโหลดภาพ: {image_path}")

            # Build a 3-channel grayscale image as the model input.
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray = image
            gray_rgb = np.stack([gray] * 3, axis=-1)

            h, w = gray_rgb.shape[:2]
            new_h, new_w = self._inference_size(h, w, self.settings['image_size'])
            gray_resized = cv2.resize(gray_rgb, (new_w, new_h))

            # HWC uint8 -> NCHW float in [-1, 1].
            gray_tensor = torch.from_numpy(gray_resized).permute(2, 0, 1).unsqueeze(0)
            gray_tensor = gray_tensor.float() / 255.0
            gray_tensor = (gray_tensor - 0.5) / 0.5

            with torch.no_grad():
                output = self.model(gray_tensor.to(self.device))

            output = output.squeeze(0).cpu().numpy()
            output = np.transpose(output, (1, 2, 0))
            # Undo the input normalisation; clip so out-of-range values
            # cannot wrap around in the uint8 cast.
            output = np.clip((output * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)

            output = cv2.resize(output, (w, h))

            if output_path:
                cv2.imwrite(output_path, cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
                print(f"✅ บันทึกภาพลงสี: {output_path}")

            return output

        except Exception as e:
            print(f"❌ ลงสีภาพล้มเหลว: {e}")
            return None
|
|
|
|
|
|
|
|
class AutoMangaColorizationSystem:
    """Top-level driver: collect data, train, checkpoint and test in cycles."""

    def __init__(self):
        print("🚀 เริ่มต้นระบบลงสีมังงะอัตโนมัติ...")
        print("=" * 50)

        self.training_system = AdaptiveTrainingSystem()
        # The collector shares the training system's settings dict.
        self.data_collector = SmartDataCollector(self.training_system.settings)
        self.cycle_count = 0

        self.load_latest_checkpoint()

    def load_latest_checkpoint(self):
        """Resume from the newest checkpoint in ``models/``, if any.

        Sets ``self.start_epoch`` to the checkpoint's epoch + 1, or to 0 when
        no checkpoint exists or listing the directory fails.
        """
        try:
            if os.path.exists('models'):
                checkpoints = [f for f in os.listdir('models') if f.startswith('checkpoint_')]
                if checkpoints:
                    # Zero-padded epoch numbers: the max name is the newest.
                    checkpoint_path = os.path.join('models', max(checkpoints))
                    self.start_epoch = self.training_system.load_checkpoint(checkpoint_path) + 1
                    print(f"🔄 ดำเนินการต่อจาก epoch {self.start_epoch}")
                    return
        except Exception as e:
            print(f"⚠️ โหลด checkpoint ล้มเหลว: {e}")

        self.start_epoch = 0

    def run_continuous_learning(self, total_cycles=100):
        """Run collect -> train -> checkpoint/test -> sleep cycles.

        Cycles run from ``self.start_epoch`` up to ``total_cycles``. A failing
        cycle is logged and followed by ``system_recovery``; the loop then
        continues with the next cycle.
        """
        print(f"🎯 เริ่มการเรียนรู้ {total_cycles} วงจร...")

        for cycle in range(self.start_epoch, total_cycles):
            self.cycle_count = cycle
            print(f"\n{'='*60}")
            print(f"🔄 วงจรที่ {cycle + 1}/{total_cycles}")
            print(f"{'='*60}")

            try:
                print("📥 กำลังรวบรวมข้อมูล...")
                image_pairs = self.data_collector.adaptive_data_collection()

                if not image_pairs:
                    print("⚠️ ไม่พบข้อมูล, ข้ามการฝึกในวงจรนี้")
                    continue

                print(f"📊 ได้ข้อมูล {len(image_pairs)} คู่ภาพ")

                dataloader = self.training_system.create_dataloader(image_pairs)

                print("🎯 เริ่มการฝึก...")
                avg_loss = self.training_system.train_epoch(dataloader, cycle)

                print(f"✅ วงจร {cycle + 1} เสร็จสิ้น, Loss: {avg_loss:.6f}")

                if cycle % 3 == 0:
                    self.training_system.save_checkpoint(cycle, avg_loss)

                if cycle % 5 == 0:
                    self.test_current_model(cycle, image_pairs[0][0])

                self.cleanup_resources()
                self.adaptive_sleep(cycle)

            except Exception as e:
                print(f"❌ วงจร {cycle} ล้มเหลว: {e}")
                self.system_recovery()

    def test_current_model(self, cycle, test_image_path):
        """Colorize one image into ``results/`` as a smoke test; never raises."""
        try:
            if not os.path.exists(test_image_path):
                print("⚠️ ไม่พบไฟล์ทดสอบ")
                return

            os.makedirs('results', exist_ok=True)
            output_path = f"results/test_cycle_{cycle:04d}.jpg"

            result = self.training_system.colorize_image(test_image_path, output_path)
            if result is not None:
                print(f"🧪 ทดสอบโมเดลวงจร {cycle} สำเร็จ")
            else:
                print("⚠️ การทดสอบล้มเหลว")

        except Exception as e:
            print(f"⚠️ การทดสอบล้มเหลว: {e}")

    def cleanup_resources(self):
        """Free cached CUDA memory, delete ``*.tmp`` downloads and run the GC."""
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            for temp_file in os.listdir('temp_downloads'):
                if not temp_file.endswith('.tmp'):
                    continue
                try:
                    os.remove(os.path.join('temp_downloads', temp_file))
                except OSError:
                    # Best effort: a file that is locked or already gone
                    # must not abort the rest of the cleanup.
                    pass

            gc.collect()

        except Exception as e:
            print(f"⚠️ ทำความสะอาดทรัพยากรล้มเหลว: {e}")

    @staticmethod
    def _sleep_seconds(system_type):
        """Return the pause between cycles, in seconds, for a system tier."""
        if system_type == 'high_end':
            return 120
        if system_type == 'medium_end':
            return 180
        return 300

    def adaptive_sleep(self, cycle):
        """Pause between cycles, polling resource usage every 30 seconds.

        Args:
            cycle: current cycle index (unused; kept for the call signature).
        """
        # .get(): a settings dict without 'system_type' gets the longest pause
        # instead of raising KeyError.
        sleep_time = self._sleep_seconds(
            self.training_system.settings.get('system_type')
        )

        print(f"⏳ พัก {sleep_time} วินาที...")

        for _ in range(sleep_time // 30):
            time.sleep(30)
            self.training_system.monitor_resources()

    def system_recovery(self):
        """Rebuild the training system after a failed cycle.

        The data collector is pointed at the new settings dict (it would
        otherwise keep reading the discarded one), and the newest checkpoint
        is reloaded so training progress is not thrown away.
        """
        print("🔄 พยายามกู้คืนระบบ...")

        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self.training_system = AdaptiveTrainingSystem()
            self.data_collector.settings = self.training_system.settings
            self.load_latest_checkpoint()

            print("✅ กู้คืนระบบสำเร็จ")

        except Exception as e:
            print(f"❌ กู้คืนระบบล้มเหลว: {e}")

    def colorize_single_image(self, image_path, output_path=None):
        """Colorize one image file.

        Args:
            image_path: image to colorize.
            output_path: destination; defaults to ``colored_<basename>`` in
                the working directory.

        Returns:
            bool: ``True`` on success, ``False`` otherwise.
        """
        if output_path is None:
            output_path = f"colored_{os.path.basename(image_path)}"

        print(f"🎨 กำลังลงสีภาพ: {image_path}")
        result = self.training_system.colorize_image(image_path, output_path)

        if result is not None:
            print(f"✅ ลงสีภาพสำเร็จ: {output_path}")
            return True

        print("❌ ลงสีภาพล้มเหลว")
        return False
|
|
|
|
|
|
|
|
def main():
    """Entry point: print the banner, build the system, run 50 learning cycles.

    Ctrl-C and any other exception are reported rather than propagated; the
    closing messages are printed in every case.
    """
    for banner_line in ("🎨 Manga Colorization System",
                        "สร้างโดย: Auto AI System",
                        "=" * 50):
        print(banner_line)

    try:
        colorizer = AutoMangaColorizationSystem()
        print("\n🎯 เริ่มกระบวนการเรียนรู้...")
        colorizer.run_continuous_learning(total_cycles=50)
    except KeyboardInterrupt:
        print("\n⏹️ หยุดระบบโดยผู้ใช้")
    except Exception as error:
        print(f"\n💥 ระบบหยุดทำงาน: {error}")
    finally:
        print("\n🧹 ทำความสะอาดทรัพยากร...")
        print("✅ โปรแกรมสิ้นสุดการทำงาน")
|
|
|
|
|
def quick_test():
    """Smoke test: collect data once, train one epoch, colorize one image.

    Any exception is caught and reported; nothing is returned.
    """
    print("🧪 โหมดทดสอบอย่างรวดเร็ว...")

    try:
        colorizer = AutoMangaColorizationSystem()

        print("🔬 ทดสอบ 1 วงจร...")
        pairs = colorizer.data_collector.adaptive_data_collection()

        if pairs:
            trainer = colorizer.training_system
            epoch_loss = trainer.train_epoch(trainer.create_dataloader(pairs), 0)
            print(f"📊 Loss: {epoch_loss:.6f}")

            # Use the first pair's grayscale image as the test input.
            first_gray_path = pairs[0][0]
            colorizer.test_current_model(0, first_gray_path)

        print("✅ การทดสอบสำเร็จ")

    except Exception as error:
        print(f"❌ การทดสอบล้มเหลว: {error}")
|
|
|
|
|
if __name__ == "__main__":
    # Command-line dispatch: --test runs the smoke test, --colorize colours a
    # single image (optionally to --output), no flags starts full training.
    import argparse

    cli = argparse.ArgumentParser(description='Manga Colorization System')
    cli.add_argument('--test', action='store_true', help='โหมดทดสอบอย่างรวดเร็ว')
    cli.add_argument('--colorize', type=str, help='ลงสีภาพเดี่ยว')
    cli.add_argument('--output', type=str, help='ไฟล์ผลลัพธ์สำหรับโหมด colorize')
    options = cli.parse_args()

    if options.test:
        quick_test()
    elif options.colorize:
        system = AutoMangaColorizationSystem()
        system.colorize_single_image(options.colorize, options.output)
    else:
        main()