import cv2 import numpy as np import tensorflow as tf from PIL import Image import os def load_image(image_path, is_line_art=False): """โหลดภาพและประมวลผลเบื้องต้น""" image = tf.io.read_file(image_path) image = tf.image.decode_image(image, channels=3) image = tf.image.resize(image, [256, 256]) image = tf.cast(image, tf.float32) / 255.0 if is_line_art: # แปลงเป็น grayscale และทำให้เส้นคมชัด image = tf.image.rgb_to_grayscale(image) # ทำให้เส้นดำสนิท พื้นขาว image = tf.where(image < 0.5, 0.0, 1.0) return image def extract_line_art_from_colored(colored_image): """สกัดเส้นจากภาพสี (ใช้สร้าง training data)""" # แปลงเป็น grayscale gray = cv2.cvtColor(colored_image, cv2.COLOR_RGB2GRAY) # ใช้ adaptive threshold เพื่อให้เส้นคมชัด line_art = cv2.adaptiveThreshold( gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2 ) # ลบ noise kernel = np.ones((2, 2), np.uint8) line_art = cv2.morphologyEx(line_art, cv2.MORPH_OPEN, kernel) return line_art def save_comparison(image, prediction, epoch, step): """บันทึกภาพเปรียบเทียบ - แก้ไขเวอร์ชัน""" try: # แปลง TensorFlow tensor เป็น numpy array input_image = image[0].numpy() # (256, 256, 1) pred_image = prediction[0].numpy() # (256, 256, 3) # แปลงภาพ input จาก grayscale เป็น RGB if input_image.shape[-1] == 1: input_image_rgb = np.repeat(input_image, 3, axis=-1) else: input_image_rgb = input_image # รวมภาพ input และ prediction ข้างกัน combined = np.concatenate([input_image_rgb, pred_image], axis=1) # คลิปค่าและแปลงเป็น uint8 combined = np.clip(combined, 0, 1) combined = (combined * 255).astype(np.uint8) # บันทึกภาพ comparison = Image.fromarray(combined) os.makedirs('output', exist_ok=True) comparison.save(f'output/epoch_{epoch:03d}_step_{step:03d}.png') print(f"💾 บันทึกภาพเปรียบเทียบ: output/epoch_{epoch:03d}_step_{step:03d}.png") except Exception as e: print(f"❌ ข้อผิดพลาดในการบันทึกภาพเปรียบเทียบ: {e}") def prepare_directories(): """สร้าง directory ที่จำเป็น""" os.makedirs('train_images/line_arts', exist_ok=True) os.makedirs('train_images/colored', exist_ok=True) os.makedirs('output', exist_ok=True) print("✅ สร้าง directory เรียบร้อย") # ฟังก์ชันเพิ่มเติมสำหรับการตรวจสอบข้อมูล def check_dataset_size(line_art_dir, colored_dir): """ตรวจสอบจำนวนไฟล์ใน dataset""" line_art_files = [f for f in os.listdir(line_art_dir) if f.endswith(('.png', '.jpg'))] colored_files = [f for f in os.listdir(colored_dir) if f.endswith(('.png', '.jpg'))] print(f"📁 ภาพเส้น: {len(line_art_files)} ไฟล์") print(f"🎨 ภาพสี: {len(colored_files)} ไฟล์") if len(line_art_files) != len(colored_files): print("⚠️ จำนวนภาพเส้นและภาพสีไม่เท่ากัน!") return False return True