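"""Smoke test: variable-length audio batching via AudioSetDataModule.collate_fn
and a training/validation step through AudioJEPAModule on the resulting batch."""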
import torch
from torch.utils.data import Dataset, DataLoader
from src.models.audio_jepa_module import AudioJEPAModule
from src.data.audioset_datamodule import AudioSetDataModule


# Mock Dataset
class MockAudioDataset(Dataset):
    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        length = self.lengths[idx]
        waveform = torch.randn(1, length)
        target = torch.randn(527)  # AudioSet classes
        return {
            "waveform": waveform,
            "target": target,
            "audio_name": f"audio_{idx}",
            "index": idx,
        }


def test_variable_length():
    # 1. Test Data Loading
    lengths = [32000, 48000, 30000, 50000]  # Variable lengths
    dataset = MockAudioDataset(lengths)

    # Use collate_fn from AudioSetDataModule
    # Pass parameters manually for testing
    def collate_fn(batch):
        return AudioSetDataModule.collate_fn(batch, hop_length=1250, patch_time_dim=16)

    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
    batch = next(iter(loader))

    waveforms = batch["waveform"]
    print(f"Batch waveforms shape: {waveforms.shape}")

    # Check if the padded shape is correct:
    #   Max length = 50000.
    #   Hop = 1250.
    #   Max spec len = 50000 // 1250 + 1 = 41.
    #   Block size = 32.
    #   Target spec len = ceil(41 / 32) * 32 = 64.
    #   Target wave len = (64 - 1) * 1250 = 63 * 1250 = 78750.
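    # Minimal sketch of the padding rule assumed above: round the spectrogram
    # length up to a multiple of the block size, then map back to samples.
    # The hop and block size mirror the hard-coded arithmetic in the comments;
    # this is an assumption about collate_fn's behaviour, not its actual code.
    hop_length, block_size = 1250, 32
    max_spec_len = max(lengths) // hop_length + 1  # 41
    padded_spec_len = (max_spec_len + block_size - 1) // block_size * block_size  # 64
    assert (padded_spec_len - 1) * hop_length == 78750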
    expected_len = 78750
    if waveforms.shape[-1] == expected_len:
        print("Padding logic verified!")
    else:
        print(
            f"Padding logic mismatch! Expected {expected_len}, got {waveforms.shape[-1]}"
        )
    # 2. Test Model Forward
    print("Initializing model...")

    # Minimal config
    net_config = {
        "spectrogram": {
            "sample_rate": 32000,
            "n_fft": 4096,
            "win_length": 4096,
            "hop_length": 1250,
            "n_mels": 128,
            "f_min": 0.0,
            "f_max": None,
            # target_length removed
        },
        "patch_embed": {
            # img_size is just for init; it is ignored/overridden dynamically.
            "img_size": (128, 256),
            "patch_size": (16, 16),
            "in_chans": 1,
            "embed_dim": 192,  # Small dim for speed
        },
        "masking": {
            "input_size": (128, 256),
            "patch_size": (16, 16),
            "mask_ratio": (0.4, 0.6),
        },
        "encoder": {
            "embed_dim": 192,
            "depth": 2,
            "num_heads": 3,
            "pos_embed_type": "rope",
            "img_size": (128, 256),
            "patch_size": (16, 16),
        },
        "predictor": {
            "embed_dim": 192,
            "depth": 1,
            "num_heads": 3,
            "pos_embed_type": "rope",
            "img_size": (128, 256),
            "patch_size": (16, 16),
        },
    }
    model = AudioJEPAModule(optimizer=torch.optim.AdamW, net=net_config)
    # Initialize EMA decay manually since we skip the Lightning loop.
    model.current_ema_decay = 0.996

    print("Running training_step...")
    loss = model.training_step(batch, 0)
    print(f"Training step loss: {loss}")

    print("Running validation_step...")
    val_loss = model.validation_step(batch, 0)
    print(f"Validation step loss: {val_loss}")

    print("Test passed!")


if __name__ == "__main__":
    test_variable_length()