Image-to-Image
Adapters
chemistry
art
Manga / train.py
K1Z3M1112's picture
Upload 6 files
69d5ab4 verified
import tensorflow as tf
from model import MangaColorizer
from dataset import MangaDataset
from utils import save_comparison, prepare_directories, check_dataset_size
from config import config
import matplotlib.pyplot as plt
import os
class TrainingMonitor(tf.keras.callbacks.Callback):
"""Callback สำหรับตรวจสอบการฝึก"""
def __init__(self, dataset):
super().__init__()
self.dataset = dataset
def on_epoch_end(self, epoch, logs=None):
# ใช้ภาพตัวอย่างจาก batch แรก
for input_image, target_image in self.dataset.take(1):
prediction = self.model.generator(input_image, training=False)
save_comparison(input_image, prediction, epoch, 0)
# พิมพ์ loss ทุก epoch
if logs:
print(f"📊 Epoch {epoch+1}: "
f"Gen Loss: {logs.get('gen_total_loss', 0):.4f}, "
f"Disc Loss: {logs.get('disc_loss', 0):.4f}")
def main():
# เตรียม directory
prepare_directories()
# ตรวจสอบ dataset
if not check_dataset_size(config.LINE_ART_DIR, config.COLORED_DIR):
print("❌ กรุณาตรวจสอบ dataset ก่อนเริ่มฝึก")
return
# โหลดข้อมูล
print("🔄 กำลังโหลดข้อมูล...")
dataset = MangaDataset().load_data()
# ตรวจสอบว่ามีข้อมูลหรือไม่
try:
sample_batch = next(iter(dataset))
print(f"✅ โหลดข้อมูลสำเร็จ: Batch size {sample_batch[0].shape}")
except StopIteration:
print("❌ ไม่มีข้อมูลใน dataset")
return
# สร้างโมเดล
print("🔄 กำลังสร้างโมเดล...")
colorizer = MangaColorizer()
# Compile โมเดล
generator_optimizer = tf.keras.optimizers.Adam(config.LEARNING_RATE, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(config.LEARNING_RATE, beta_1=0.5)
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
colorizer.compile(
g_optimizer=generator_optimizer,
d_optimizer=discriminator_optimizer,
loss_fn=loss_fn
)
print(f"✅ สร้างโมเดลสำเร็จ")
print(f"📈 เริ่มฝึก {config.EPOCHS} epochs...")
# ฝึกโมเดล
history = colorizer.fit(
dataset,
epochs=config.EPOCHS,
callbacks=[TrainingMonitor(dataset)],
verbose=1
)
# บันทึกโมเดล
os.makedirs('output', exist_ok=True)
colorizer.generator.save('output/manga_colorizer.h5')
print("✅ บันทึกโมเดลเรียบร้อย: output/manga_colorizer.h5")
# พล็อตกราฟ loss
try:
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
if 'gen_total_loss' in history.history:
plt.plot(history.history['gen_total_loss'], label='Generator Total Loss')
if 'gen_gan_loss' in history.history:
plt.plot(history.history['gen_gan_loss'], label='Generator GAN Loss', linestyle='--')
if 'gen_l1_loss' in history.history:
plt.plot(history.history['gen_l1_loss'], label='Generator L1 Loss', linestyle=':')
plt.title('Generator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
if 'disc_loss' in history.history:
plt.plot(history.history['disc_loss'], label='Discriminator Loss', color='red')
plt.title('Discriminator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.tight_layout()
plt.savefig('output/training_loss.png', dpi=300, bbox_inches='tight')
plt.close()
print("✅ บันทึกกราฟ training loss: output/training_loss.png")
except Exception as e:
print(f"❌ ไม่สามารถบันทึกกราฟ loss: {e}")
if __name__ == "__main__":
main()