how to finetune

#1
by buaadwxl - opened

Thank you for your great work! I'm confused about how to perform finetuning. Should it be entirely based on the original weights of Trellis? During training, should I freeze some parameters or train all parameters? What is the learning rate for training? And what is the total number of training steps?

Hey!

Yes, I initialized all models with the original weights from the pretrained TRELLIS models. I only finetuned the flow matching transformers for the sparse structure (ss_flow) and the structured latents (slat_flow), which handle structure generation and structured latents generation respectively. I finetuned one model at a time and kept all other models frozen. I did not test finetuning the mesh or Gaussian splat decoders of the VAE; my hope is that it is similar to finetuning Stable Diffusion models, where most people focus on the latent diffusion model and not the VAE. The training routine itself is based on the official repository; I only implemented a finetuning routine around it so I can load the safetensors files before finetuning one of the models. I tested different hyperparameters for both models:

Sparse Structure (ss_flow)

You can find a wandb report of the sweep for the sparse structure (ss_flow) here: https://api.wandb.ai/links/damian-boborzi/k0aoqwzv
At the bottom you can find the influence of the different parameter setups on the sparse structure reconstruction quality (measured as MSE).
However, finetuning this model did not seem to improve performance significantly. The second model had a greater impact on the final generation quality.

Structured Latents (slat_flow)

You can find a wandb report of the sweep for the structured latent (slat_flow) here: https://api.wandb.ai/links/damian-boborzi/esloucbn
At the bottom you can find the influence of the different parameter setups on the quality based on the CLIP-S using rendered images.
This setup, for example, performed well:

  • base_config:"./configs/generation/slat_flow_txt_dit_XL_64l8p2_fp16_finetune.json"
  • ema_rate:0.9992401601507136
  • finetune_ckpt:"./assets/pretrained_TRELLIS_txt_xl/snapshots/e0b00432b8e3a8ecee0df806ab1df9f7281f2be4/ckpts/slat_flow_txt_dit_XL_64l8p2_fp16.safetensors"
  • learning_rate:0.00003490315455061004
  • max_steps:250000
  • output_dir_base:"./outputs/sweeps_slat_flow_txt"
  • p_uncond:0.10131071004600269
  • weight_decay:0.04869847905146491
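The `ema_rate` above controls an exponential moving average of the model weights; the update rule it refers to is the standard one (sketched here in plain Python with illustrative names, not code from the repository):

```python
def ema_update(ema_params: dict, params: dict, ema_rate: float = 0.9992) -> dict:
    """One EMA step: ema <- rate * ema + (1 - rate) * current weights.

    With ema_rate close to 1, the averaged weights change slowly and
    smooth out noise from individual gradient updates, which is often
    what gets evaluated and shipped instead of the raw weights.
    """
    for name, value in params.items():
        ema_params[name] = ema_rate * ema_params[name] + (1.0 - ema_rate) * value
    return ema_params
```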

I also want to add that we noticed the image conditioned model already performs really well, and finetuning may only help a little, at least if the finetuning data comes from a domain already covered by TRELLIS500K. The impact of finetuning on the text-based models was much greater.

We plan on releasing the finetuning code very soon, but I hope this helps in the meantime. If you have any further questions feel free to ask :)

Thank you for your detailed reply. Could you please share some code examples on "how to load the safetensor files before fine-tuning one of the models"? I am currently trying to fine-tune the Trellis model (image to 3D), but the training results are collapsing. I would appreciate your help.

I'm currently troubled by this issue: when I compute the absolute difference (abs) between the weights you provided (the fine-tuned ones) and the original Trellis weights, the difference is around 1e-3. However, when I compute the abs difference between the weights from my own fine-tuning and the original Trellis weights, the difference is around 1e-2. This significant discrepancy is causing my inference to fail.
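For reference, the weight comparison described above can be sketched like this (assuming PyTorch state dicts with identical keys; the function name is illustrative):

```python
import torch


def mean_abs_weight_diff(sd_a: dict, sd_b: dict) -> float:
    """Mean absolute elementwise difference between two state dicts
    with identical keys and shapes."""
    total, count = 0.0, 0
    for name, tensor_a in sd_a.items():
        diff = (tensor_a.float() - sd_b[name].float()).abs()
        total += diff.sum().item()
        count += diff.numel()
    return total / count
```

A jump from ~1e-3 to ~1e-2 in such a metric means the finetuned weights drifted an order of magnitude further from the initialization, which could point to a learning rate that is too high or to evaluating the raw weights instead of the EMA weights.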
