NSGF++: Neural Sinkhorn Gradient Flow

Reproduction of arXiv:2401.14069

Setup

git clone https://huggingface.co/rogermt/nsgf-plusplus
cd nsgf-plusplus
pip install torch torchvision numpy scipy scikit-learn matplotlib geomloss pot tqdm pyyaml
# Optional: pip install pykeops (KeOps backend that lets GeomLoss run large Sinkhorn problems on GPU)

Quick start: 2D experiments

# Full-scale 8gaussians (paper Table 1, ~10 min on GPU)
python main.py --experiment 2d --dataset 8gaussians --steps 10

# Quick test (< 1 min)
python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 10 --train-iters 1000

# All 2D datasets, at --steps 10 and --steps 100
for ds in 8gaussians moons scurve checkerboard; do
  python main.py --experiment 2d --dataset $ds --steps 10
  python main.py --experiment 2d --dataset $ds --steps 100
done
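
The 2D runs above use 10 or 100 flow steps. To illustrate what a single Sinkhorn gradient-flow step does to a particle cloud, here is a minimal NumPy sketch. It assumes uniform weights, the cost C(x, y) = |x − y|²/2, a plain log-domain Sinkhorn loop, and the standard first variation of the Sinkhorn divergence. Under that cost, the gradient of the entropic potential f_{μ,ν} at x is x − T_ν(x), where T_ν is the entropic barycentric projection onto ν. The velocity −∇(f_{μ,ν} − f_{μ,μ}) therefore reduces to T_ν(x) − T_μ(x). This is not the code in sinkhorn_flow.py, which uses GeomLoss, and all function names here are hypothetical.

```python
import numpy as np


def _logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = a.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(a - m).sum(axis=axis))


def sinkhorn_potential_on_target(x, y, eps, n_iter=200):
    """Log-domain Sinkhorn between uniform clouds x (n, d) and y (m, d).

    Returns the dual potential g on the support of y for the cost
    |x - y|^2 / 2 and entropic regularization eps (hypothetical helper).
    """
    n, m = len(x), len(y)
    log_a, log_b = np.full(n, -np.log(n)), np.full(m, -np.log(m))
    cost = 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):
        f = -eps * _logsumexp((g[None, :] - cost) / eps + log_b[None, :], axis=1)
        g = -eps * _logsumexp((f[:, None] - cost) / eps + log_a[:, None], axis=0)
    return g


def barycentric_map(x, y, g, eps):
    """T(x) = sum_j w_j(x) y_j with w(x) = softmax_j((g_j - |x - y_j|^2 / 2) / eps).

    For the potential f dual to g, grad f(x) = x - T(x) under this cost.
    """
    cost = 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    logits = (g[None, :] - cost) / eps
    logits -= logits.max(axis=1, keepdims=True)  # stabilize the softmax
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)
    return w @ y


def sinkhorn_flow_step(x, y, eps=1.0, step_size=0.5):
    """One explicit-Euler step x <- x - step_size * grad(f_{x,y} - f_{x,x})(x)."""
    g_xy = sinkhorn_potential_on_target(x, y, eps)  # potential for OT_eps(x, y)
    g_xx = sinkhorn_potential_on_target(x, x, eps)  # potential for OT_eps(x, x)
    velocity = barycentric_map(x, y, g_xy, eps) - barycentric_map(x, x, g_xx, eps)
    return x + step_size * velocity


# Toy run: move a standard Gaussian cloud toward a shifted one in 10 steps.
rng = np.random.default_rng(0)
source = rng.normal(size=(100, 2))
target = rng.normal(size=(100, 2)) + np.array([4.0, 0.0])
particles = source.copy()
for _ in range(10):
    particles = sinkhorn_flow_step(particles, target)
```

The self term T_μ(x) is what separates a Sinkhorn-divergence flow from a plain entropic-OT flow: it cancels the entropic blur that would otherwise pull particles toward the target's local means.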

Image experiments (NSGF++)

# MNIST (paper: FID=3.8, NFE=60)
python main.py --experiment mnist

# CIFAR-10 (paper: FID=5.55, IS=8.86, NFE=59)
python main.py --experiment cifar10
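
NSGF++ samples in two phases. A few Sinkhorn-flow steps first move noise toward the data manifold. A straight-flow (NSF) refinement then runs from a time chosen by the phase-transition predictor. The NumPy sketch below shows that control flow and how velocity evaluations add up to an NFE count. The callables `nsgf_velocity`, `phase_predictor` and `nsf_velocity`, their signatures, and the per-sample Euler schedule are assumptions for illustration, not the interfaces in inference.py or model.py. This sketch counts only velocity evaluations, not the predictor call.

```python
import numpy as np


def nsgf_plus_plus_sample(x0, nsgf_velocity, phase_predictor, nsf_velocity,
                          n_nsgf_steps=5, nsgf_step_size=1.0, n_nsf_steps=10):
    """Hypothetical two-phase sampler: NSGF Euler steps, then straight-flow Euler steps.

    x0: (n, d) initial noise. Returns (samples, nfe), where nfe counts calls
    to the two velocity fields (the predictor call is not counted).
    """
    x = np.array(x0, dtype=float)
    nfe = 0
    # Phase 1: a few explicit-Euler steps along the learned Sinkhorn flow.
    for k in range(n_nsgf_steps):
        x = x + nsgf_step_size * nsgf_velocity(x, k)
        nfe += 1
    # Phase transition: predict, per sample, the straight-flow start time t in [0, 1].
    t = np.clip(phase_predictor(x), 0.0, 1.0)
    dt = (1.0 - t) / n_nsf_steps
    # Phase 2: Euler integration of the straight flow from t to 1.
    for _ in range(n_nsf_steps):
        x = x + dt[:, None] * nsf_velocity(x, t)
        t = t + dt
        nfe += 1
    return x, nfe


# Toy check with hand-written velocity fields that point at a fixed point c.
c = np.array([1.0, -2.0])
rng = np.random.default_rng(0)
samples, nfe = nsgf_plus_plus_sample(
    rng.normal(size=(8, 2)),
    nsgf_velocity=lambda x, k: c - x,
    phase_predictor=lambda x: np.full(len(x), 0.3),
    nsf_velocity=lambda x, t: (c - x) / (1.0 - t[:, None]),
    nsgf_step_size=0.5,
)
```

With these defaults the toy run uses 5 + 10 = 15 velocity evaluations. The straight-flow velocity (c − x)/(1 − t) is exact along the line to c, so the Euler phase lands on c.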

Files

File                 Description
config.yaml          All hyperparameters from the paper
main.py              CLI entry point
dataset_loader.py    2D synthetic + MNIST/CIFAR-10 loaders
sinkhorn_flow.py     Sinkhorn potentials (GeomLoss), gradient flow, trajectory pool
model.py             VelocityMLP (2D), VelocityUNet (images), PhaseTransitionPredictor
trainer.py           NSGF, NSF, phase predictor, and NSGF++ trainers
inference.py         NSGF and NSGF++ samplers
evaluation.py        W2 distance, FID, IS, visualization
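
The file list places a trajectory pool in sinkhorn_flow.py and velocity-field trainers in trainer.py. The NumPy sketch below shows one way these could fit together. It assumes an NSGF-style recipe: store (position, step, Sinkhorn velocity) triples from simulated flows, then regress a network onto them. The `TrajectoryPool` class, its methods, and the mean-squared velocity-matching loss are hypothetical illustrations, not the repo's API. A real trainer would feed sampled batches to VelocityMLP or VelocityUNet and minimize this loss with an optimizer.

```python
import numpy as np


class TrajectoryPool:
    """Hypothetical store of (x, step, v) triples collected along Sinkhorn flows."""

    def __init__(self):
        self._x, self._step, self._v = [], [], []

    def add(self, x, step, v):
        """Add one batch: particles x (n, d), their flow step, velocities v (n, d)."""
        self._x.append(np.asarray(x, dtype=float))
        self._step.append(np.full(len(x), step, dtype=int))
        self._v.append(np.asarray(v, dtype=float))

    def __len__(self):
        return sum(len(x) for x in self._x)

    def sample(self, batch_size, rng):
        """Draw a random minibatch (with replacement) across all stored steps."""
        x, step, v = (np.concatenate(a) for a in (self._x, self._step, self._v))
        idx = rng.integers(0, len(x), size=batch_size)
        return x[idx], step[idx], v[idx]


def velocity_matching_loss(v_pred, v_target):
    """Batch mean of the squared Euclidean error between predicted and target velocities."""
    return float(np.mean(np.sum((v_pred - v_target) ** 2, axis=1)))


# Toy usage: three flow steps of 4 particles each, with velocity v = -x.
rng = np.random.default_rng(0)
pool = TrajectoryPool()
for step in range(3):
    x = rng.normal(size=(4, 2))
    pool.add(x, step, -x)
xb, sb, vb = pool.sample(batch_size=6, rng=rng)
loss = velocity_matching_loss(-xb, vb)  # a perfect predictor gives zero loss
```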

Paper targets

Experiment              Metric            Target
8gaussians, 10 steps    W2                0.285
MNIST                   FID / NFE         3.8 / 60
CIFAR-10                FID / IS / NFE    5.55 / 8.86 / 59
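
For the 2D target, W2 is the 2-Wasserstein distance between generated and reference samples. For two equal-size point clouds with uniform weights, an optimal plan can be taken to be a permutation, so the exact value reduces to a linear assignment problem. The sketch below uses SciPy's `linear_sum_assignment` under that assumption. evaluation.py may compute W2 differently (POT is also among the dependencies), and the number of samples used affects the reported value.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def w2_equal_size(x, y):
    """Exact 2-Wasserstein distance between two uniform point clouds of equal size.

    With n points per side and weights 1/n, some optimal transport plan is a
    permutation (Birkhoff-von Neumann), so W2^2 is the minimum mean squared
    Euclidean cost over assignments.
    """
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape (n, d)")
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    rows, cols = linear_sum_assignment(cost)
    return float(np.sqrt(cost[rows, cols].mean()))


rng = np.random.default_rng(0)
samples = rng.normal(size=(256, 2))
# A pure translation by (3, 4) has W2 equal to its length, so this is approximately 5.0.
print(w2_equal_size(samples, samples + np.array([3.0, 4.0])))
```

Exact assignment costs O(n³) time, so for large sample sets an entropic (Sinkhorn) approximation is a common cheaper alternative.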