|
|
import cv2 |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from PIL import Image |
|
|
import os |
|
|
|
|
|
def load_image(image_path, is_line_art=False): |
|
|
"""โหลดภาพและประมวลผลเบื้องต้น""" |
|
|
image = tf.io.read_file(image_path) |
|
|
image = tf.image.decode_image(image, channels=3) |
|
|
image = tf.image.resize(image, [256, 256]) |
|
|
image = tf.cast(image, tf.float32) / 255.0 |
|
|
|
|
|
if is_line_art: |
|
|
|
|
|
image = tf.image.rgb_to_grayscale(image) |
|
|
|
|
|
image = tf.where(image < 0.5, 0.0, 1.0) |
|
|
|
|
|
return image |
|
|
|
|
|
def extract_line_art_from_colored(colored_image): |
|
|
"""สกัดเส้นจากภาพสี (ใช้สร้าง training data)""" |
|
|
|
|
|
gray = cv2.cvtColor(colored_image, cv2.COLOR_RGB2GRAY) |
|
|
|
|
|
|
|
|
line_art = cv2.adaptiveThreshold( |
|
|
gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, |
|
|
cv2.THRESH_BINARY_INV, 11, 2 |
|
|
) |
|
|
|
|
|
|
|
|
kernel = np.ones((2, 2), np.uint8) |
|
|
line_art = cv2.morphologyEx(line_art, cv2.MORPH_OPEN, kernel) |
|
|
|
|
|
return line_art |
|
|
|
|
|
def save_comparison(image, prediction, epoch, step): |
|
|
"""บันทึกภาพเปรียบเทียบ - แก้ไขเวอร์ชัน""" |
|
|
try: |
|
|
|
|
|
input_image = image[0].numpy() |
|
|
pred_image = prediction[0].numpy() |
|
|
|
|
|
|
|
|
if input_image.shape[-1] == 1: |
|
|
input_image_rgb = np.repeat(input_image, 3, axis=-1) |
|
|
else: |
|
|
input_image_rgb = input_image |
|
|
|
|
|
|
|
|
combined = np.concatenate([input_image_rgb, pred_image], axis=1) |
|
|
|
|
|
|
|
|
combined = np.clip(combined, 0, 1) |
|
|
combined = (combined * 255).astype(np.uint8) |
|
|
|
|
|
|
|
|
comparison = Image.fromarray(combined) |
|
|
os.makedirs('output', exist_ok=True) |
|
|
comparison.save(f'output/epoch_{epoch:03d}_step_{step:03d}.png') |
|
|
print(f"💾 บันทึกภาพเปรียบเทียบ: output/epoch_{epoch:03d}_step_{step:03d}.png") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ ข้อผิดพลาดในการบันทึกภาพเปรียบเทียบ: {e}") |
|
|
|
|
|
def prepare_directories(): |
|
|
"""สร้าง directory ที่จำเป็น""" |
|
|
os.makedirs('train_images/line_arts', exist_ok=True) |
|
|
os.makedirs('train_images/colored', exist_ok=True) |
|
|
os.makedirs('output', exist_ok=True) |
|
|
print("✅ สร้าง directory เรียบร้อย") |
|
|
|
|
|
|
|
|
def check_dataset_size(line_art_dir, colored_dir): |
|
|
"""ตรวจสอบจำนวนไฟล์ใน dataset""" |
|
|
line_art_files = [f for f in os.listdir(line_art_dir) if f.endswith(('.png', '.jpg'))] |
|
|
colored_files = [f for f in os.listdir(colored_dir) if f.endswith(('.png', '.jpg'))] |
|
|
|
|
|
print(f"📁 ภาพเส้น: {len(line_art_files)} ไฟล์") |
|
|
print(f"🎨 ภาพสี: {len(colored_files)} ไฟล์") |
|
|
|
|
|
if len(line_art_files) != len(colored_files): |
|
|
print("⚠️ จำนวนภาพเส้นและภาพสีไม่เท่ากัน!") |
|
|
return False |
|
|
return True |