Flow Matching fMRI Encoder

A two-stage architecture for predicting fMRI BOLD responses from naturalistic video stimuli (Friends TV show + Movie10 dataset) using Conditional Flow Matching.


Overview

The pipeline predicts brain activity in two sequential stages:

Stage 1 (MultiSubjectConvLinearEncoder): predicts a Mean Anchor, a deterministic per-voxel fMRI estimate shared across subjects.
Stage 2 (CFM, Conditional Flow Matching): learns a per-subject neural vector field that refines the Mean Anchor into a sharper, stochastic fMRI prediction.

The design mirrors the MedARC approach: Stage 1 provides a stable conditional mean $\mu$; Stage 2 integrates a continuous normalizing flow $\phi_t$ conditioned on $\mu$ to sample from the true posterior over voxel activations.


Architecture

Stage 1 β€” Mean Anchor (medarc_architecture.py)

Features (N, T, D_i)  →  DepthConv1d + Linear  →  embed (N, T, d)
                                                         ↓
                                               [Shared Decoder]  +  [Subject Decoders]
                                                         ↓
                                               fMRI Prediction (N, S, T, V)
  • MultiSubjectConvLinearEncoder: projects each feature stream through a LinearConv (depthwise conv + linear) to a shared embedding dimension d = 192.
  • Global average pooling across the feature stack combines multi-model features.
  • A shared linear decoder and per-subject linear decoders combine additively to predict V voxels for each of S subjects.
  • Trained with MSE loss against ground-truth BOLD.

Stage 2 β€” Neural Vector Field (matcha_architecture.py)

x1 (B, V, T)  →  proj_in (V → d)  →  latent x1 (B, d, T)
mu (B, V, T)  →  proj_in (V → d)  →  latent mu (B, d, T)
                                              ↓
                              OT-CFM loss (vector field u_t)
                              Matcha-TTS U-Net estimator
                              [Conformer / Transformer blocks]
                                              ↓
                              latent pred (B, d, T)  →  proj_out (d → V)  →  fMRI (B, V, T)
  • Latent bottleneck: fMRI voxels (V ≈ 1000) are projected down to a dense latent dimension (d = 128) before any convolution, reducing the first-layer parameter count from ~6M to ~98K and preventing gradient collapse.
  • CFM wraps a Matcha-TTS style 1D U-Net (Decoder) with ResNet-1D blocks and Conformer/Transformer attention at each scale.
  • At inference, noise $z \sim \mathcal{N}(0, I)$ is integrated from $t = 0$ to $t = 1$ over 25 Euler steps, conditioned on $\mu$.
  • An auxiliary reconstruction loss (weight 0.1) on proj_in → proj_out trains the projection pair jointly with the vector field.

Data

  • Friends (seasons 1–7): fMRI BOLD + multimodal features. Usage: train (S1), validation (S6, S7).
  • Movie10: fMRI BOLD + multimodal features. Usage: supplementary validation (Figures, Life, Bourne, Wolf).

Subjects used: 1, 2, 3, 5.

Feature Models

The encoder can ingest intermediate activations from any combination of:

Key                                  Model
internvl3_8b, internvl3_14b          InternVL3 vision-language model
qwen-2-5-omni-3b, qwen-2-5-omni-7b   Qwen2.5-Omni audio-video model
whisper                              OpenAI Whisper (audio)
llama_3.2_1b, llama_3.2_3b           LLaMA 3.2 (text)
vjepa2                               V-JEPA 2 (video)

Active features are set in config.yml under include_features.
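A hypothetical entry might look like the following (key names follow the table above; the exact schema is whatever config.yml defines):

```yaml
include_features:
  - internvl3_8b
  - qwen-2-5-omni-7b
  - whisper
```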


File Structure

flow_matching/
├── config.yml               # Main training config (GPU, full data)
├── debug_config.yml         # Fast local debug config (CPU, tiny data)
├── environment.yml          # Conda environment spec
│
├── src/
│   ├── training.py          # Two-stage training loop + evaluation
│   ├── matcha_architecture.py  # CFM + Matcha-TTS U-Net decoder
│   ├── medarc_architecture.py  # Stage 1 MultiSubjectConvLinearEncoder
│   ├── data.py              # Algonauts2025 dataset + loaders
│   ├── metric.py            # Pearson's r voxel-wise scoring
│   ├── visualize.py         # Loss curve plotting
│   └── inference.py         # Standalone inference helper
│
├── test/
│   ├── overfit_test.py      # Tiny-batch overfit sanity check for Stage 2
│   ├── check_pearson.py     # Load checkpoints and plot per-voxel Pearson's r heatmaps
│   └── debug_training.py    # End-to-end smoke test
│
├── experiments/
│   └── *.ipynb              # Analysis notebooks (RSA, OOD, brain region plots)
│
└── Matcha-TTS/              # Vendored Matcha-TTS source (U-Net + solver)

Training

Full training (server)

cd flow_matching
python src/training.py --cfg-path config.yml

Checkpoints are written to output/two_stage_encoding/:

  • stage1_best.pt β€” best Stage 1 model by validation Pearson's r
  • stage2_epoch_N.pt β€” Stage 2 snapshot every 5 epochs

Local debug (CPU, tiny model)

python src/training.py --cfg-path debug_config.yml

Evaluation

Pearson's r heatmaps

Loads all available Stage 1 and Stage 2 checkpoints, evaluates on the configured validation set, and saves per-subject per-voxel Pearson's r heatmaps to output/two_stage_encoding/heatmaps/.

python test/check_pearson.py

Output per checkpoint:

Stage 1 Overall Pearson's r: 0.1832
Stage 1 - Sub 1 Mean Pearson's r: 0.1754
Stage 2 Epoch 5 Overall Pearson's r: 0.2110
Stage 2 Epoch 5 - Sub 1 Mean Pearson's r: 0.2043
...

Tiny-batch overfit test

Confirms Stage 2 can memorize a single training batch. If the loss does not approach 0 within 500 steps, the architecture likely cannot learn the task.

python test/overfit_test.py --cfg-path config.yml --subject-idx 0 --steps 500

Key Hyperparameters

Parameter                 Value         Location
Stage 1 embed dim         192           config.yml / stage1.model.embed_dim
Stage 1 encoder kernel    45            config.yml / stage1.model.encoder_kernel_size
Stage 1 LR                3e-4          config.yml / stage1.lr
Stage 2 latent dim        128           config.yml / stage2.latent_dim
Stage 2 U-Net channels    [256, 256]    config.yml / stage2.decoder.channels
Stage 2 block type        Conformer     config.yml / stage2.decoder.*_block_type
Stage 2 LR                3e-4          config.yml / stage2.lr
Euler integration steps   25            config.yml / stage2.n_timesteps
CFM σ_min                 1e-4          config.yml / stage2.cfm.sigma_min

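The 25-step Euler integration used at inference can be sketched in NumPy. The vector field below is a hypothetical stand-in (a constant field pointing from the noise draw toward μ) chosen so the result is easy to verify; the real field is the trained U-Net conditioned on μ:

```python
import numpy as np

def euler_sample(v_field, z, mu, n_steps=25):
    """Integrate dx/dt = v_field(x, t, mu) from t=0 to t=1 with fixed Euler steps."""
    x = z
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = i * dt
        x = x + dt * v_field(x, t, mu)  # one explicit Euler step
    return x

rng = np.random.default_rng(2)
mu = rng.standard_normal((1, 128, 10))  # Stage 1 mean anchor in latent space
z = rng.standard_normal((1, 128, 10))   # noise draw z ~ N(0, I)

# Stand-in field: constant (mu - z), so the 25 steps sum to exactly (mu - z)
# and the trajectory lands on mu at t = 1.
field = lambda x, t, mu: mu - z
out = euler_sample(field, z, mu)
print(np.allclose(out, mu))  # True
```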
Metric

Evaluation uses voxel-wise Pearson's r averaged across subjects:

$$r_v = \frac{\sum_t \left(y_v^t - \bar{y}_v\right)\left(\hat{y}_v^t - \bar{\hat{y}}_v\right)}{\left\lVert \mathbf{y}_v - \bar{y}_v \right\rVert \cdot \left\lVert \hat{\mathbf{y}}_v - \bar{\hat{y}}_v \right\rVert}$$

The scalar reported is the mean over all V voxels and all S subjects.
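The per-voxel correlation is straightforward to implement; a minimal NumPy version (illustrative, not the code in src/metric.py) is:

```python
import numpy as np

def voxelwise_pearson(y, y_hat):
    """Per-voxel Pearson's r over the time axis.

    y, y_hat: (V, T) ground-truth and predicted BOLD time series.
    Returns:  (V,) correlation for each voxel.
    """
    yc = y - y.mean(axis=1, keepdims=True)          # center each voxel's series
    pc = y_hat - y_hat.mean(axis=1, keepdims=True)
    num = (yc * pc).sum(axis=1)
    denom = np.linalg.norm(yc, axis=1) * np.linalg.norm(pc, axis=1)
    return num / denom

rng = np.random.default_rng(3)
y = rng.standard_normal((5, 100))
r_perfect = voxelwise_pearson(y, 2.0 * y + 1.0)  # any affine match gives r = 1
print(np.allclose(r_perfect, 1.0))  # True
```

Averaging the returned vector over voxels (and then over subjects) gives the reported scalar.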
