Whisper-CLAP-v0.1

Whisper-CLAP is a lightweight audio-text model that projects audio into the sentence embedding space of GTE-Base-en-v1.5.

It acts as a semantic bridge, allowing you to measure the similarity between audio content (speech, sound effects, music) and complex textual descriptions. Unlike traditional CLAP models trained on short tags, this model is optimized for long, descriptive captions generated by Gemini Flash 2.0.

Model Architecture

The model uses a frozen text encoder (GTE) as the target and trains a projection head on top of the Whisper encoder to match the text embeddings.

  • Audio Encoder: openai/whisper-small (Average pooling of the last hidden state).
  • Projection Head: Linear (768 → 2048) → GELU → Linear (2048 → 768).
  • Text Encoder: Alibaba-NLP/gte-base-en-v1.5 (Frozen during training).
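
Both encoders end in the same 768-dimensional, L2-normalized embedding space, so audio-text similarity reduces to a dot product. A minimal sketch with random placeholder tensors (shapes are illustrative only, not real encoder outputs):

import torch
import torch.nn.functional as F

# Placeholder embeddings standing in for real encoder outputs:
# 4 audio clips (Whisper encoder + projection head) and 3 candidate captions (frozen GTE).
audio_emb = F.normalize(torch.randn(4, 768), p=2, dim=1)
text_emb = F.normalize(torch.randn(3, 768), p=2, dim=1)

# Cosine similarity matrix of shape (4, 3): one score per audio/caption pair.
similarity = audio_emb @ text_emb.t()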

Training Details

  • Status: Research Preview / Work-in-Progress (v0.1).
  • Data: Trained on a mix of approximately 10 million samples:
    • Speech snippets.
    • AI-generated music.
    • Sound effect datasets (Re-captioned FreeSound, Re-captioned AudioSet).
  • Captions: All audio was re-captioned using Gemini Flash 2.0. The captions are highly descriptive, focusing on emotions, speaker attributes, and musical styles (e.g., "A melancholic piano melody with a slow tempo and lots of reverb," "An angry man shouting in a large, echoing hall").
  • Hardware: Trained on 2x NVIDIA RTX 3090.
  • Duration: ~5 days.

Preliminary Results

  • Internal validation set (400 samples; music, speech, and SFX): R@1 35.50%, R@5 61.50%
  • AudioCaps: R@1 1.1%, R@5 3.8%
  • Clotho: R@1 2.9%, R@5 11.6%
  • ESC50 (zero-shot classification): Top-1 57.7%, Top-5 82.6%

Usage

This model is useful for:

  1. RLHF / Ranking: Scoring generated audio (TTS or Music) against a descriptive prompt.
  2. Retrieval: Searching large audio databases using natural language.
  3. Zero-Shot Classification: Classifying audio by comparing it to candidate text descriptions (see the sketch after the inference example below).

Inference Example

You can use the following code to load the model and run inference:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import WhisperModel, WhisperFeatureExtractor, AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# --- Model Definition ---
class WhisperClapModel(nn.Module):
    def __init__(self, whisper_name="openai/whisper-small"):
        super().__init__()
        self.audio_encoder = WhisperModel.from_pretrained(whisper_name).encoder
        self.projector = nn.Sequential(
            nn.Linear(768, 2048),
            nn.GELU(),
            nn.Linear(2048, 768)
        )

    def forward(self, input_features):
        # Extract features and mean pool
        outputs = self.audio_encoder(input_features)
        rep = outputs.last_hidden_state.mean(dim=1) 
        # Project to GTE space and normalize
        emb = self.projector(rep)
        return F.normalize(emb, p=2, dim=1)

# --- Setup ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
WHISPER_NAME = "openai/whisper-small"
TEXT_MODEL_NAME = "Alibaba-NLP/gte-base-en-v1.5"

# 1. Load Model
model = WhisperClapModel(WHISPER_NAME).to(DEVICE)

# Download and load weights
checkpoint_path = hf_hub_download(repo_id="laion/whisper-clap-version-0.1", filename="model.safetensors")
state_dict = load_file(checkpoint_path)

# Clean keys (remove 'model.' prefix if present from training)
clean_state = {k.replace("model.", ""): v for k, v in state_dict.items()}
model.load_state_dict(clean_state, strict=False)
model.eval()

# 2. Load Processors
feature_extractor = WhisperFeatureExtractor.from_pretrained(WHISPER_NAME)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = AutoModel.from_pretrained(TEXT_MODEL_NAME, trust_remote_code=True).to(DEVICE).eval()

# 3. Helper Functions
def encode_audio(audio_path):
    wav, sr = torchaudio.load(audio_path)
    # Resample to 16k
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    # Mono
    if wav.shape[0] > 1:
        wav = torch.mean(wav, dim=0, keepdim=True)
    # Pad/Trim to 30s
    target_len = 16000 * 30
    if wav.shape[1] > target_len:
        wav = wav[:, :target_len]
    elif wav.shape[1] < target_len:
        wav = F.pad(wav, (0, target_len - wav.shape[1]))
        
    inputs = feature_extractor(wav.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        return model(inputs.input_features.to(DEVICE))

def encode_text(text):
    inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = text_model(**inputs)
        # GTE pooling
        mask = inputs.attention_mask.unsqueeze(-1).expand(outputs.last_hidden_state.size()).float()
        sum_embeddings = torch.sum(outputs.last_hidden_state * mask, 1)
        sum_mask = torch.clamp(mask.sum(1), min=1e-9)
        return F.normalize(sum_embeddings / sum_mask, p=2, dim=1)

# --- Run ---
# audio_emb = encode_audio("my_audio.wav")
# text_emb = encode_text("A calm, soothing voice explaining a concept.")
# similarity = torch.matmul(audio_emb, text_emb.t()).item()
# print(f"Similarity: {similarity:.4f}")

Training

If you want to fine-tune this model or train it from scratch on your own data, here is a basic training snippet.

Data Requirement: This script expects a folder containing pairs of files with the same basename:

  • Audio file (e.g., recording_001.mp3 or .wav)
  • Metadata file (e.g., recording_001.json) containing a "caption" field.
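
For reference, here is a minimal sketch of producing one such pair; the file name and caption are made up for illustration, and the metadata may contain additional fields as long as a "caption" key is present:

import json

# Hypothetical pair inside ./my_training_data: recording_001.wav + recording_001.json
metadata = {"caption": "A melancholic piano melody with a slow tempo and lots of reverb."}
with open("recording_001.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, ensure_ascii=False)

The training script itself:
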
import os
import json
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor, AutoTokenizer, AutoModel

# --- Configuration ---
DATA_FOLDER = "./my_training_data"  # Path to your data
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- 1. Dataset Loader ---
class AudioTextDataset(Dataset):
    def __init__(self, folder_path):
        self.samples = []
        # Find all audio files (mp3 and wav)
        audio_files = glob.glob(os.path.join(folder_path, "*.[mM][pP]3")) + \
                      glob.glob(os.path.join(folder_path, "*.[wW][aA][vV]"))
        
        for audio_path in audio_files:
            # Construct expected JSON path (same basename)
            base_path = os.path.splitext(audio_path)[0]
            json_path = base_path + ".json"
            
            if os.path.exists(json_path):
                try:
                    with open(json_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                        # Ensure your JSON has a 'caption' key, or adjust this line
                        caption = data.get("caption") 
                        if caption:
                            self.samples.append((audio_path, caption))
                except Exception as e:
                    print(f"Error loading {json_path}: {e}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        audio_path, text = self.samples[idx]
        
        # Load and process Audio
        wav, sr = torchaudio.load(audio_path)
        if sr != 16000:
            wav = torchaudio.functional.resample(wav, sr, 16000)
        # Mix to mono if necessary
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
            
        # Random crop or pad to 30s (simple strategy for training)
        target_len = 16000 * 30
        if wav.shape[1] > target_len:
            # Random crop
            start = torch.randint(0, wav.shape[1] - target_len, (1,)).item()
            wav = wav[:, start:start+target_len]
        elif wav.shape[1] < target_len:
            wav = F.pad(wav, (0, target_len - wav.shape[1]))

        return {"audio": wav.squeeze(), "text": text}

# --- 2. Training Model Wrapper ---
class WhisperClapTrainer(nn.Module):
    def __init__(self):
        super().__init__()
        # Audio Encoder (Trainable)
        self.whisper = WhisperModel.from_pretrained("openai/whisper-small")
        self.audio_encoder = self.whisper.encoder
        
        # Projection Head (Trainable)
        self.projector = nn.Sequential(
            nn.Linear(768, 2048),
            nn.GELU(),
            nn.Linear(2048, 768)
        )
        
        # Text Encoder (Frozen - we want to align audio to the fixed GTE space)
        self.text_model = AutoModel.from_pretrained("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)
        for p in self.text_model.parameters():
            p.requires_grad = False

    def forward(self, input_features, input_ids, attention_mask):
        # 1. Encode Audio
        audio_out = self.audio_encoder(input_features)
        audio_rep = audio_out.last_hidden_state.mean(dim=1)
        audio_emb = F.normalize(self.projector(audio_rep), p=2, dim=1)

        # 2. Encode Text (with Gradient disabled)
        with torch.no_grad():
            text_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
            # GTE specific pooling logic
            mask_expanded = attention_mask.unsqueeze(-1).expand(text_out.last_hidden_state.size()).float()
            sum_embeddings = torch.sum(text_out.last_hidden_state * mask_expanded, 1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            text_emb = F.normalize(sum_embeddings / sum_mask, p=2, dim=1)

        # 3. Compute Loss (MSE or Cosine Embedding Loss)
        # We maximize cosine similarity (minimize 1 - cos_sim)
        cos_sim = F.cosine_similarity(audio_emb, text_emb)
        loss = 1.0 - cos_sim.mean()
        
        return loss

# --- 3. Main Training Loop ---
def train():
    # Tools
    feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
    tokenizer = AutoTokenizer.from_pretrained("Alibaba-NLP/gte-base-en-v1.5")
    
    # Data
    dataset = AudioTextDataset(DATA_FOLDER)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    
    # Model
    model = WhisperClapTrainer().to(DEVICE)
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)

    model.train()
    print(f"Starting training on {len(dataset)} pairs...")

    for epoch in range(EPOCHS):
        total_loss = 0
        for batch in dataloader:
            optimizer.zero_grad()
            
            # Prepare Inputs
            audio_inputs = feature_extractor(batch["audio"].numpy(), sampling_rate=16000, return_tensors="pt")
            text_inputs = tokenizer(batch["text"], padding=True, truncation=True, max_length=512, return_tensors="pt")
            
            # Move to device
            input_features = audio_inputs.input_features.to(DEVICE)
            input_ids = text_inputs.input_ids.to(DEVICE)
            attn_mask = text_inputs.attention_mask.to(DEVICE)

            # Forward & Backward
            loss = model(input_features, input_ids, attn_mask)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
        print(f"Epoch {epoch+1}/{EPOCHS} - Avg Loss: {total_loss / len(dataloader):.4f}")

    # Save final model
    torch.save(model.state_dict(), "whisper_clap_finetuned.pt")
    print("Training complete. Model saved.")

if __name__ == "__main__":
    train()
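
The saved checkpoint contains the full trainer state (Whisper encoder, projection head, and the frozen GTE weights). A hedged sketch of loading just the audio branch back into the WhisperClapModel class from the inference example, assuming the attribute names used in WhisperClapTrainer above:

import torch

# Keep only the audio-branch weights; "audio_encoder." and "projector." are the
# attribute names registered in WhisperClapTrainer, so they appear as key prefixes.
state = torch.load("whisper_clap_finetuned.pt", map_location="cpu")
audio_state = {k: v for k, v in state.items()
               if k.startswith("audio_encoder.") or k.startswith("projector.")}

# WhisperClapModel and DEVICE are defined in the inference example above.
inference_model = WhisperClapModel().to(DEVICE)
inference_model.load_state_dict(audio_state, strict=False)
inference_model.eval()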