|
|
import tensorflow as tf |
|
|
from model import MangaColorizer |
|
|
from dataset import MangaDataset |
|
|
from utils import save_comparison, prepare_directories, check_dataset_size |
|
|
from config import config |
|
|
import matplotlib.pyplot as plt |
|
|
import os |
|
|
|
|
|
class TrainingMonitor(tf.keras.callbacks.Callback):
    """Callback for monitoring training.

    At the end of every epoch it runs the generator on one batch taken
    from the supplied dataset, hands the result to ``save_comparison``,
    and prints the generator / discriminator losses found in ``logs``.
    """

    def __init__(self, dataset):
        """Remember the dataset that preview batches are drawn from.

        Args:
            dataset: Object exposing ``take(n)`` whose elements unpack as
                ``(input_image, target_image)`` pairs (presumably a
                ``tf.data.Dataset`` -- confirm against
                ``MangaDataset.load_data``).
        """
        super().__init__()
        self.dataset = dataset

    def on_epoch_end(self, epoch, logs=None):
        """Save a preview for one batch and report the epoch's losses.

        Args:
            epoch: Zero-based epoch index supplied by Keras; printed as
                ``epoch + 1``.
            logs: Optional mapping of metric names to values. Nothing is
                printed when it is ``None`` or empty.
        """
        # Only the first batch is used; the target half of the pair is
        # not needed for the preview.
        preview_batches = self.dataset.take(1)
        for line_art, _target in preview_batches:
            colorized = self.model.generator(line_art, training=False)
            save_comparison(line_art, colorized, epoch, 0)

        if not logs:
            return

        # Missing metrics are reported as 0 rather than raising.
        gen_loss = logs.get('gen_total_loss', 0)
        disc_loss = logs.get('disc_loss', 0)
        print(
            f"📊 Epoch {epoch+1}: "
            f"Gen Loss: {gen_loss:.4f}, "
            f"Disc Loss: {disc_loss:.4f}"
        )
|
|
|
|
|
def _plot_training_history(history_dict, output_path):
    """Draw generator and discriminator loss curves and save them as an image.

    Args:
        history_dict: Mapping of metric name to a per-epoch sequence of
            values (the ``history`` attribute of the object returned by
            ``fit``). Keys that are absent are simply not plotted.
        output_path: File path passed to ``plt.savefig``.

    Raises:
        Exception: Whatever matplotlib raises while drawing or saving; the
            figure is closed before the exception propagates.
    """
    # (subplot title, ((history key, plot keyword arguments), ...))
    panels = (
        ('Generator Loss', (
            ('gen_total_loss', {'label': 'Generator Total Loss'}),
            ('gen_gan_loss', {'label': 'Generator GAN Loss', 'linestyle': '--'}),
            ('gen_l1_loss', {'label': 'Generator L1 Loss', 'linestyle': ':'}),
        )),
        ('Discriminator Loss', (
            ('disc_loss', {'label': 'Discriminator Loss', 'color': 'red'}),
        )),
    )

    figure = plt.figure(figsize=(12, 4))
    try:
        for position, (title, curves) in enumerate(panels, start=1):
            plt.subplot(1, len(panels), position)
            drew_curve = False
            for key, plot_kwargs in curves:
                if key in history_dict:
                    plt.plot(history_dict[key], **plot_kwargs)
                    drew_curve = True
            plt.title(title)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            # A legend on an axes without labelled curves only produces a
            # warning and an empty box, so skip it in that case.
            if drew_curve:
                plt.legend()

        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, including when drawing or saving fails.
        plt.close(figure)


def main():
    """Run the full training pipeline for the manga colorizer.

    Steps, in order: prepare directories, validate the dataset folders,
    load the dataset, build and compile ``MangaColorizer``, train it for
    ``config.EPOCHS`` epochs with a ``TrainingMonitor`` callback, save the
    generator to ``output/manga_colorizer.h5`` and write a loss plot to
    ``output/training_loss.png``.

    Returns early (with a printed message) when the dataset check fails or
    the loaded dataset yields no batch. A failure while writing the loss
    plot is reported but does not abort, because the model has already
    been saved at that point.
    """
    prepare_directories()

    if not check_dataset_size(config.LINE_ART_DIR, config.COLORED_DIR):
        print("❌ กรุณาตรวจสอบ dataset ก่อนเริ่มฝึก")
        return

    print("🔄 กำลังโหลดข้อมูล...")
    dataset = MangaDataset().load_data()

    # Pull one batch up front so an empty dataset is reported before any
    # model is built. Only the iterator step is guarded.
    try:
        sample_batch = next(iter(dataset))
    except StopIteration:
        print("❌ ไม่มีข้อมูลใน dataset")
        return
    print(f"✅ โหลดข้อมูลสำเร็จ: Batch size {sample_batch[0].shape}")

    print("🔄 กำลังสร้างโมเดล...")
    colorizer = MangaColorizer()

    # Generator and discriminator each get their own Adam instance with
    # the same learning rate and beta_1.
    generator_optimizer = tf.keras.optimizers.Adam(config.LEARNING_RATE, beta_1=0.5)
    discriminator_optimizer = tf.keras.optimizers.Adam(config.LEARNING_RATE, beta_1=0.5)
    loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    colorizer.compile(
        g_optimizer=generator_optimizer,
        d_optimizer=discriminator_optimizer,
        loss_fn=loss_fn,
    )

    print("✅ สร้างโมเดลสำเร็จ")
    print(f"📈 เริ่มฝึก {config.EPOCHS} epochs...")

    history = colorizer.fit(
        dataset,
        epochs=config.EPOCHS,
        callbacks=[TrainingMonitor(dataset)],
        verbose=1,
    )

    os.makedirs('output', exist_ok=True)
    colorizer.generator.save('output/manga_colorizer.h5')
    print("✅ บันทึกโมเดลเรียบร้อย: output/manga_colorizer.h5")

    # Plotting is best-effort: the trained model is already on disk, so a
    # plotting failure is reported instead of crashing the run.
    try:
        _plot_training_history(history.history, 'output/training_loss.png')
    except Exception as e:
        print(f"❌ ไม่สามารถบันทึกกราฟ loss: {e}")
    else:
        print("✅ บันทึกกราฟ training loss: output/training_loss.png")
|
|
|
|
|
# Start training only when this file is executed as a script, so that
# importing the module does not kick off a training run.
if __name__ == "__main__":
    main()