# Flow Matching fMRI Encoder
A two-stage architecture for predicting fMRI BOLD responses from naturalistic video stimuli (Friends TV show + Movie10 dataset) using Conditional Flow Matching.
## Overview

The pipeline predicts brain activity in two sequential stages:
| Stage | Model | Goal |
|---|---|---|
| 1 | `MultiSubjectConvLinearEncoder` | Predict a **Mean Anchor**: a deterministic per-voxel fMRI estimate shared across subjects |
| 2 | CFM (Conditional Flow Matching) | Learn a per-subject neural vector field that refines the Mean Anchor into a sharper, stochastic fMRI prediction |
The design mirrors the MedARC approach: Stage 1 provides a stable conditional mean $\mu$; Stage 2 integrates a continuous normalizing flow $\phi_t$ conditioned on $\mu$ to sample from the posterior over voxel activations.
## Architecture

### Stage 1: Mean Anchor (`medarc_architecture.py`)
```
Features (N, T, D_i) → DepthConv1d + Linear → embed (N, T, d)
                          ↓
        [Shared Decoder] + [Subject Decoders]
                          ↓
            fMRI Prediction (N, S, T, V)
```
- `MultiSubjectConvLinearEncoder`: projects each feature stream through a `LinearConv` (depthwise conv + linear) to a shared embedding dimension `d = 192`.
- Global average pooling across the feature stack combines multi-model features.
- A shared linear decoder and per-subject linear decoders combine additively to predict `V` voxels for each of `S` subjects.
- Trained with MSE loss against ground-truth BOLD.
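The additive shared + per-subject decoding can be sketched as follows. This is a minimal PyTorch illustration of the design described above; the class name, exact shapes, and voxel count are assumptions, not the repo's API (see `src/medarc_architecture.py` for the real implementation).

```python
import torch
import torch.nn as nn

class AdditiveSubjectHead(nn.Module):
    """Shared decoder plus per-subject decoders, combined by addition."""
    def __init__(self, d: int, n_subjects: int, n_voxels: int):
        super().__init__()
        self.shared = nn.Linear(d, n_voxels)  # one decoder for all subjects
        self.per_subject = nn.ModuleList(
            [nn.Linear(d, n_voxels) for _ in range(n_subjects)]  # one head per subject
        )

    def forward(self, embed: torch.Tensor) -> torch.Tensor:
        # embed: (N, T, d) -> prediction: (N, S, T, V)
        shared = self.shared(embed)  # (N, T, V), common across subjects
        subj = torch.stack([head(embed) for head in self.per_subject], dim=1)  # (N, S, T, V)
        return shared.unsqueeze(1) + subj  # additive combination

head = AdditiveSubjectHead(d=192, n_subjects=4, n_voxels=1000)
pred = head(torch.randn(2, 10, 192))
print(pred.shape)  # torch.Size([2, 4, 10, 1000])
```

The shared decoder captures subject-invariant structure, so each per-subject head only needs to learn a residual correction.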
### Stage 2: Neural Vector Field (`matcha_architecture.py`)
```
x1 (B, V, T) → proj_in (V → d) → latent x1 (B, d, T)
mu (B, V, T) → proj_in (V → d) → latent mu (B, d, T)
                          ↓
             OT-CFM loss (vector field u_t)
              Matcha-TTS U-Net estimator
           [Conformer / Transformer blocks]
                          ↓
latent pred (B, d, T) → proj_out (d → V) → fMRI (B, V, T)
```
- **Latent bottleneck**: fMRI voxels (`V ≈ 1000`) are projected down to a dense latent dimension (`d = 128`) before any convolution, reducing the first-layer parameter count from ~6M to ~98K and preventing gradient collapse.
- `CFM` wraps a Matcha-TTS style 1D U-Net (`Decoder`) with ResNet-1D blocks and Conformer/Transformer attention at each scale.
- At inference, noise $z \sim \mathcal{N}(0, I)$ is integrated from $t=0$ to $t=1$ over 25 Euler steps, conditioned on $\mu$.
- An auxiliary reconstruction loss (weight `0.1`) on `proj_in → proj_out` trains the projection pair jointly with the vector field.
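The training objective and sampling loop above can be sketched in a few lines. This is a hedged illustration of standard OT-CFM with the document's `sigma_min = 1e-4` and 25 Euler steps; `vector_field` is a toy stand-in for the Matcha-TTS U-Net, and all function names and shapes are assumptions rather than the repo's API.

```python
import torch

def vector_field(x_t, t, mu):
    # Placeholder estimator: the real model is a conditional 1D U-Net.
    return mu - x_t + t.view(-1, 1, 1)

def cfm_loss(x1, mu, sigma_min=1e-4):
    """One OT-CFM training step: regress the vector field onto u_t."""
    b = x1.shape[0]
    t = torch.rand(b)                    # t ~ U[0, 1], one per sample
    x0 = torch.randn_like(x1)            # noise endpoint
    t_ = t.view(-1, 1, 1)
    x_t = (1 - (1 - sigma_min) * t_) * x0 + t_ * x1  # point on the probability path
    u_t = x1 - (1 - sigma_min) * x0                  # target velocity
    return ((vector_field(x_t, t, mu) - u_t) ** 2).mean()

def sample(mu, n_steps=25):
    """Integrate noise from t=0 to t=1 with fixed-step Euler, conditioned on mu."""
    x = torch.randn_like(mu)
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((mu.shape[0],), i * dt)
        x = x + dt * vector_field(x, t, mu)
    return x

mu = torch.randn(2, 128, 40)             # latent Mean Anchor (B, d, T)
loss = cfm_loss(torch.randn_like(mu), mu)
pred = sample(mu)
print(pred.shape)  # torch.Size([2, 128, 40])
```

At `sigma_min → 0` the path degenerates to straight-line interpolation between noise and data, which is why small values like `1e-4` are typical.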
## Data
| Source | Content | Usage |
|---|---|---|
| Friends (seasons 1β7) | fMRI BOLD + multimodal features | Train (S1), Val (S6, S7) |
| Movie10 | fMRI BOLD + multimodal features | Supplementary val (Figures, Life, Bourne, Wolf) |
Subjects used: 1, 2, 3, 5.
## Feature Models

The encoder can ingest intermediate activations from any combination of:
| Key | Model |
|---|---|
| `internvl3_8b`, `internvl3_14b` | InternVL3 vision-language model |
| `qwen-2-5-omni-3b`, `qwen-2-5-omni-7b` | Qwen2.5-Omni audio-video model |
| `whisper` | OpenAI Whisper (audio) |
| `llama_3.2_1b`, `llama_3.2_3b` | LLaMA 3.2 (text) |
| `vjepa2` | V-JEPA 2 (video) |
Active features are set in `config.yml` under `include_features`.
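A plausible fragment of that config might look as follows; only the `include_features` key is documented above, so the surrounding structure is an assumption:

```yaml
include_features:
  - internvl3_8b
  - whisper
  - vjepa2
```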
## File Structure
```
flow_matching/
├── config.yml                   # Main training config (GPU, full data)
├── debug_config.yml             # Fast local debug config (CPU, tiny data)
├── environment.yml              # Conda environment spec
│
├── src/
│   ├── training.py              # Two-stage training loop + evaluation
│   ├── matcha_architecture.py   # CFM + Matcha-TTS U-Net decoder
│   ├── medarc_architecture.py   # Stage 1 MultiSubjectConvLinearEncoder
│   ├── data.py                  # Algonauts2025 dataset + loaders
│   ├── metric.py                # Pearson's r voxel-wise scoring
│   ├── visualize.py             # Loss curve plotting
│   └── inference.py             # Standalone inference helper
│
├── test/
│   ├── overfit_test.py          # Tiny-batch overfit sanity check for Stage 2
│   ├── check_pearson.py         # Load checkpoints, plot per-voxel Pearson's r heatmaps
│   └── debug_training.py        # End-to-end smoke test
│
├── experiments/
│   └── *.ipynb                  # Analysis notebooks (RSA, OOD, brain region plots)
│
└── Matcha-TTS/                  # Vendored Matcha-TTS source (U-Net + solver)
```
## Training

### Full training (server)

```shell
cd flow_matching
python src/training.py --cfg-path config.yml
```
Checkpoints are written to `output/two_stage_encoding/`:

- `stage1_best.pt`: best Stage 1 model by validation Pearson's r
- `stage2_epoch_N.pt`: Stage 2 snapshot every 5 epochs
### Local debug (CPU, tiny model)

```shell
python src/training.py --cfg-path debug_config.yml
```
## Evaluation

### Pearson's r heatmaps

Loads all available Stage 1 and Stage 2 checkpoints, evaluates on the configured validation set, and saves per-subject, per-voxel Pearson's r heatmaps to `output/two_stage_encoding/heatmaps/`.

```shell
python test/check_pearson.py
```
Output per checkpoint:

```
Stage 1 Overall Pearson's r: 0.1832
Stage 1 - Sub 1 Mean Pearson's r: 0.1754
Stage 2 Epoch 5 Overall Pearson's r: 0.2110
Stage 2 Epoch 5 - Sub 1 Mean Pearson's r: 0.2043
...
```
### Tiny-batch overfit test

Confirms that Stage 2 can memorize a single training batch. If the loss does not approach 0 within 500 steps, the architecture cannot learn the task.

```shell
python test/overfit_test.py --cfg-path config.yml --subject-idx 0 --steps 500
```
## Key Hyperparameters
| Parameter | Value | Location |
|---|---|---|
| Stage 1 embed dim | 192 | config.yml / stage1.model.embed_dim |
| Stage 1 encoder kernel | 45 | config.yml / stage1.model.encoder_kernel_size |
| Stage 1 LR | 3e-4 | config.yml / stage1.lr |
| Stage 2 latent dim | 128 | config.yml / stage2.latent_dim |
| Stage 2 U-Net channels | [256, 256] | config.yml / stage2.decoder.channels |
| Stage 2 block type | Conformer | config.yml / stage2.decoder.*_block_type |
| Stage 2 LR | 3e-4 | config.yml / stage2.lr |
| Euler integration steps | 25 | config.yml / stage2.n_timesteps |
| CFM σ_min | 1e-4 | config.yml / stage2.cfm.sigma_min |
## Metric

Evaluation uses voxel-wise Pearson's r between predicted and ground-truth BOLD time courses, averaged across subjects:

$$
r_{s,v} = \frac{\sum_t (\hat{y}_{s,v,t} - \bar{\hat{y}}_{s,v})(y_{s,v,t} - \bar{y}_{s,v})}{\sqrt{\sum_t (\hat{y}_{s,v,t} - \bar{\hat{y}}_{s,v})^2 \sum_t (y_{s,v,t} - \bar{y}_{s,v})^2}}
$$

The reported scalar is the mean of $r_{s,v}$ over all $V$ voxels and all $S$ subjects.
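A minimal NumPy sketch of this metric, computing per-voxel Pearson's r over time and averaging the result (the repo's `src/metric.py` may differ in implementation details):

```python
import numpy as np

def voxelwise_pearson(pred, target):
    """pred, target: (S, T, V) arrays. Returns per-voxel r of shape (S, V)."""
    p = pred - pred.mean(axis=1, keepdims=True)      # center over time
    t = target - target.mean(axis=1, keepdims=True)
    num = (p * t).sum(axis=1)
    den = np.sqrt((p ** 2).sum(axis=1) * (t ** 2).sum(axis=1))
    return num / den

rng = np.random.default_rng(0)
target = rng.standard_normal((4, 50, 10))  # S=4 subjects, T=50 TRs, V=10 voxels
r = voxelwise_pearson(target, target)      # a perfect prediction gives r = 1
score = float(r.mean())                    # scalar reported
print(round(score, 4))  # 1.0
```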