Whisper-CLAP-v0.1
Whisper-CLAP is a lightweight audio-text model that projects audio into the sentence embedding space of GTE-Base-en-v1.5.
It acts as a semantic bridge, allowing you to measure the similarity between audio content (speech, sound effects, music) and complex textual descriptions. Unlike traditional CLAP models trained on short tags, this model is optimized for long, descriptive captions generated by Gemini Flash 2.0.
Model Architecture
The model uses a frozen text encoder (GTE) as the target and trains a projection head on top of the Whisper encoder to match the text embeddings.
- Audio Encoder: openai/whisper-small (average pooling of the last hidden state).
- Projection Head: Linear (768 → 2048) → GELU → Linear (2048 → 768).
- Text Encoder: Alibaba-NLP/gte-base-en-v1.5 (frozen during training).
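For orientation, here is a minimal, self-contained sketch of that dimension flow using random stand-in tensors (the real encoders are wired up in the inference example below): Whisper-small's encoder states are mean-pooled to 768 dimensions, pushed through the projection head, and compared to a GTE sentence embedding via cosine similarity.
import torch
import torch.nn as nn
import torch.nn.functional as F

# Whisper-small encoder output for a 30 s clip: (batch, 1500 frames, 768)
hidden_states = torch.randn(1, 1500, 768)            # stand-in for encoder.last_hidden_state
pooled = hidden_states.mean(dim=1)                    # (1, 768) mean-pooled audio representation

projector = nn.Sequential(nn.Linear(768, 2048), nn.GELU(), nn.Linear(2048, 768))
audio_emb = F.normalize(projector(pooled), dim=1)     # (1, 768), unit-norm, lives in GTE space

text_emb = F.normalize(torch.randn(1, 768), dim=1)    # stand-in for a GTE caption embedding
score = (audio_emb @ text_emb.T).item()               # cosine similarity, since both are normalized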
Training Details
- Status: Research Preview / Work-in-Progress (v0.1).
- Data: Trained on a mix of approximately 10 million samples:
- Speech snippets.
- AI-generated music.
- Sound effect datasets (Re-captioned FreeSound, Re-captioned AudioSet).
- Captions: All audio was re-captioned using Gemini Flash 2.0. The captions are highly descriptive, focusing on emotions, speaker attributes, and musical styles (e.g., "A melancholic piano melody with a slow tempo and lots of reverb," "An angry man shouting in a large, echoing hall").
- Hardware: Trained on 2x NVIDIA RTX 3090.
- Duration: ~5 days.
Preliminary Results
- Internal validation set (400 samples: music, speech, and SFX): R@1 35.50%, R@5 61.50%
- AudioCaps: R@1 1.1%, R@5 3.8%
- Clotho: R@1 2.9%, R@5 11.6%
- ESC50 (zero-shot classification): Top-1 57.7%, Top-5 82.6%
Usage
This model is useful for:
- RLHF / Ranking: Scoring generated audio (TTS or Music) against a descriptive prompt.
- Retrieval: Searching large audio databases using natural language.
- Zero-Shot Classification: Classifying audio by comparing it to candidate text descriptions.
Inference Example
You can use the following code to load the model and run inference:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import WhisperModel, WhisperFeatureExtractor, AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# --- Model Definition ---
class WhisperClapModel(nn.Module):
    def __init__(self, whisper_name="openai/whisper-small"):
        super().__init__()
        self.audio_encoder = WhisperModel.from_pretrained(whisper_name).encoder
        self.projector = nn.Sequential(
            nn.Linear(768, 2048),
            nn.GELU(),
            nn.Linear(2048, 768)
        )

    def forward(self, input_features):
        # Extract features and mean pool
        outputs = self.audio_encoder(input_features)
        rep = outputs.last_hidden_state.mean(dim=1)
        # Project to GTE space and normalize
        emb = self.projector(rep)
        return F.normalize(emb, p=2, dim=1)
# --- Setup ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
WHISPER_NAME = "openai/whisper-small"
TEXT_MODEL_NAME = "Alibaba-NLP/gte-base-en-v1.5"
# 1. Load Model
model = WhisperClapModel(WHISPER_NAME).to(DEVICE)
# Download and load weights
checkpoint_path = hf_hub_download(repo_id="laion/whisper-clap-version-0.1", filename="model.safetensors")
state_dict = load_file(checkpoint_path)
# Clean keys (remove 'model.' prefix if present from training)
clean_state = {k.replace("model.", ""): v for k, v in state_dict.items()}
model.load_state_dict(clean_state, strict=False)
model.eval()
# 2. Load Processors
feature_extractor = WhisperFeatureExtractor.from_pretrained(WHISPER_NAME)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = AutoModel.from_pretrained(TEXT_MODEL_NAME, trust_remote_code=True).to(DEVICE).eval()
# 3. Helper Functions
def encode_audio(audio_path):
    wav, sr = torchaudio.load(audio_path)
    # Resample to 16 kHz
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    # Mix down to mono
    if wav.shape[0] > 1:
        wav = torch.mean(wav, dim=0, keepdim=True)
    # Pad/trim to 30 s
    target_len = 16000 * 30
    if wav.shape[1] > target_len:
        wav = wav[:, :target_len]
    elif wav.shape[1] < target_len:
        wav = F.pad(wav, (0, target_len - wav.shape[1]))
    inputs = feature_extractor(wav.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        return model(inputs.input_features.to(DEVICE))
def encode_text(text):
    inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = text_model(**inputs)
        # GTE pooling (mean over non-padding tokens)
        mask = inputs.attention_mask.unsqueeze(-1).expand(outputs.last_hidden_state.size()).float()
        sum_embeddings = torch.sum(outputs.last_hidden_state * mask, 1)
        sum_mask = torch.clamp(mask.sum(1), min=1e-9)
        return F.normalize(sum_embeddings / sum_mask, p=2, dim=1)
# --- Run ---
# audio_emb = encode_audio("my_audio.wav")
# text_emb = encode_text("A calm, soothing voice explaining a concept.")
# similarity = torch.matmul(audio_emb, text_emb.t()).item()
# print(f"Similarity: {similarity:.4f}")
Training
If you want to fine-tune this model or train it from scratch on your own data, here is a basic training snippet.
Data Requirement: This script expects a folder containing pairs of files with the same basename:
- Audio file (e.g., recording_001.mp3 or .wav)
- Metadata file (e.g., recording_001.json) containing a "caption" field.
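For reference, a pair could be prepared like this; the folder, file name, and caption below are illustrative, and only the "caption" key is read by the loader.
import json
import os

# Illustrative pair: recording_001.wav (or .mp3) plus recording_001.json with a "caption" key.
os.makedirs("./my_training_data", exist_ok=True)
metadata = {"caption": "A calm female voice reading a bedtime story over soft rain sounds."}
with open("./my_training_data/recording_001.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, ensure_ascii=False, indent=2)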
import os
import json
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor, AutoTokenizer, AutoModel
# --- Configuration ---
DATA_FOLDER = "./my_training_data" # Path to your data
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# --- 1. Dataset Loader ---
class AudioTextDataset(Dataset):
    def __init__(self, folder_path):
        self.samples = []
        # Find all audio files (mp3 and wav)
        audio_files = glob.glob(os.path.join(folder_path, "*.[mM][pP]3")) + \
                      glob.glob(os.path.join(folder_path, "*.[wW][aA][vV]"))
        for audio_path in audio_files:
            # Construct expected JSON path (same basename)
            base_path = os.path.splitext(audio_path)[0]
            json_path = base_path + ".json"
            if os.path.exists(json_path):
                try:
                    with open(json_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    # Ensure your JSON has a 'caption' key, or adjust this line
                    caption = data.get("caption")
                    if caption:
                        self.samples.append((audio_path, caption))
                except Exception as e:
                    print(f"Error loading {json_path}: {e}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        audio_path, text = self.samples[idx]
        # Load and process audio
        wav, sr = torchaudio.load(audio_path)
        if sr != 16000:
            wav = torchaudio.functional.resample(wav, sr, 16000)
        # Mix to mono if necessary
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        # Random crop or pad to 30 s (simple strategy for training)
        target_len = 16000 * 30
        if wav.shape[1] > target_len:
            # Random crop
            start = torch.randint(0, wav.shape[1] - target_len, (1,)).item()
            wav = wav[:, start:start + target_len]
        elif wav.shape[1] < target_len:
            wav = F.pad(wav, (0, target_len - wav.shape[1]))
        return {"audio": wav.squeeze(), "text": text}
# --- 2. Training Model Wrapper ---
class WhisperClapTrainer(nn.Module):
    def __init__(self):
        super().__init__()
        # Audio encoder (trainable)
        self.whisper = WhisperModel.from_pretrained("openai/whisper-small")
        self.audio_encoder = self.whisper.encoder
        # Projection head (trainable)
        self.projector = nn.Sequential(
            nn.Linear(768, 2048),
            nn.GELU(),
            nn.Linear(2048, 768)
        )
        # Text encoder (frozen - we want to align audio to the fixed GTE space)
        self.text_model = AutoModel.from_pretrained("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)
        for p in self.text_model.parameters():
            p.requires_grad = False

    def forward(self, input_features, input_ids, attention_mask):
        # 1. Encode audio
        audio_out = self.audio_encoder(input_features)
        audio_rep = audio_out.last_hidden_state.mean(dim=1)
        audio_emb = F.normalize(self.projector(audio_rep), p=2, dim=1)
        # 2. Encode text (with gradients disabled)
        with torch.no_grad():
            text_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
            # GTE-specific pooling logic (mean over non-padding tokens)
            mask_expanded = attention_mask.unsqueeze(-1).expand(text_out.last_hidden_state.size()).float()
            sum_embeddings = torch.sum(text_out.last_hidden_state * mask_expanded, 1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            text_emb = F.normalize(sum_embeddings / sum_mask, p=2, dim=1)
        # 3. Compute loss: maximize cosine similarity (minimize 1 - cos_sim)
        cos_sim = F.cosine_similarity(audio_emb, text_emb)
        loss = 1.0 - cos_sim.mean()
        return loss
# --- 3. Main Training Loop ---
def train():
    # Tools
    feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
    tokenizer = AutoTokenizer.from_pretrained("Alibaba-NLP/gte-base-en-v1.5")
    # Data
    dataset = AudioTextDataset(DATA_FOLDER)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    # Model
    model = WhisperClapTrainer().to(DEVICE)
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)
    model.train()
    print(f"Starting training on {len(dataset)} pairs...")
    for epoch in range(EPOCHS):
        total_loss = 0
        for batch in dataloader:
            optimizer.zero_grad()
            # Prepare inputs
            audio_inputs = feature_extractor(batch["audio"].numpy(), sampling_rate=16000, return_tensors="pt")
            text_inputs = tokenizer(batch["text"], padding=True, truncation=True, max_length=512, return_tensors="pt")
            # Move to device
            input_features = audio_inputs.input_features.to(DEVICE)
            input_ids = text_inputs.input_ids.to(DEVICE)
            attn_mask = text_inputs.attention_mask.to(DEVICE)
            # Forward & backward
            loss = model(input_features, input_ids, attn_mask)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{EPOCHS} - Avg Loss: {total_loss / len(dataloader):.4f}")
    # Save final model
    torch.save(model.state_dict(), "whisper_clap_finetuned.pt")
    print("Training complete. Model saved.")


if __name__ == "__main__":
    train()
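To reuse a checkpoint saved by this script with the WhisperClapModel class from the inference example, one option (a sketch under the assumptions above, not part of the released code) is a non-strict load: the trainer's state dict also contains the frozen text encoder and Whisper's decoder, which the inference model simply ignores.
# Sketch: load the fine-tuned weights into the inference-side WhisperClapModel.
# Keys such as "text_model.*" and "whisper.decoder.*" have no counterpart there,
# so strict=False skips them; "audio_encoder.*" and "projector.*" are picked up.
state_dict = torch.load("whisper_clap_finetuned.pt", map_location="cpu")
model = WhisperClapModel("openai/whisper-small")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
model.eval()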