File size: 3,509 Bytes
eca55dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
from torch.utils.data import Dataset, DataLoader
from src.models.audio_jepa_module import AudioJEPAModule
from src.data.audioset_datamodule import AudioSetDataModule


# Synthetic stand-in for a real audio dataset.
class MockAudioDataset(Dataset):
    """Dataset of random mono waveforms with caller-chosen sample counts.

    Each item is a dict carrying a ``(1, length)`` waveform tensor, a
    527-dim random target vector (one entry per AudioSet class), a
    synthetic audio name, and the item's index.
    """

    def __init__(self, lengths):
        # One entry per sample: the waveform length in audio frames.
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        n_samples = self.lengths[idx]
        return {
            "waveform": torch.randn(1, n_samples),
            "target": torch.randn(527),  # AudioSet classes
            "audio_name": f"audio_{idx}",
            "index": idx,
        }


def test_variable_length():
    """End-to-end smoke test: variable-length collation + AudioJEPA steps.

    Builds a mock batch of waveforms with differing lengths, checks the
    padding arithmetic of ``AudioSetDataModule.collate_fn``, then pushes
    the batch through one training and one validation step of a tiny
    ``AudioJEPAModule``.
    """
    # --- 1. Data loading with heterogeneous waveform lengths ---
    sample_lengths = [32000, 48000, 30000, 50000]  # Variable lengths
    mock_ds = MockAudioDataset(sample_lengths)

    # Bind the datamodule's collate_fn to the test's hop/patch parameters.
    collate = lambda items: AudioSetDataModule.collate_fn(
        items, hop_length=1250, patch_time_dim=16
    )

    loader = DataLoader(mock_ds, batch_size=4, collate_fn=collate)
    first_batch = next(iter(loader))
    padded = first_batch["waveform"]
    print(f"Batch waveforms shape: {padded.shape}")

    # Expected padded length:
    #   longest waveform = 50000, hop = 1250
    #   spec frames      = 50000 // 1250 + 1 = 41
    #   block size 32 -> ceil(41 / 32) * 32 = 64 frames
    #   waveform samples = (64 - 1) * 1250 = 78750
    expected_len = 78750
    if padded.shape[-1] == expected_len:
        print("Padding logic verified!")
    else:
        print(
            f"Padding logic mismatch! Expected {expected_len}, got {padded.shape[-1]}"
        )

    # --- 2. Model forward passes with a deliberately tiny configuration ---
    print("Initializing model...")
    net_config = {
        "spectrogram": {
            "sample_rate": 32000,
            "n_fft": 4096,
            "win_length": 4096,
            "hop_length": 1250,
            "n_mels": 128,
            "f_min": 0.0,
            "f_max": None,
            # target_length removed
        },
        "patch_embed": {
            # img_size only seeds the init; per the original author it is
            # ignored/overridden dynamically at run time.
            "img_size": (128, 256),
            "patch_size": (16, 16),
            "in_chans": 1,
            "embed_dim": 192,  # small width keeps the test fast
        },
        "masking": {
            "input_size": (128, 256),
            "patch_size": (16, 16),
            "mask_ratio": (0.4, 0.6),
        },
        "encoder": {
            "embed_dim": 192,
            "depth": 2,
            "num_heads": 3,
            "pos_embed_type": "rope",
            "img_size": (128, 256),
            "patch_size": (16, 16),
        },
        "predictor": {
            "embed_dim": 192,
            "depth": 1,
            "num_heads": 3,
            "pos_embed_type": "rope",
            "img_size": (128, 256),
            "patch_size": (16, 16),
        },
    }

    model = AudioJEPAModule(optimizer=torch.optim.AdamW, net=net_config)

    # The Lightning loop normally manages this; set it by hand since we
    # call the step methods directly.
    model.current_ema_decay = 0.996

    print("Running training_step...")
    train_loss = model.training_step(first_batch, 0)
    print(f"Training step loss: {train_loss}")

    print("Running validation_step...")
    val_loss = model.validation_step(first_batch, 0)
    print(f"Validation step loss: {val_loss}")

    print("Test passed!")


# Allow running this smoke test directly as a script.
if __name__ == "__main__":
    test_variable_length()