"""Fast audio captioning: CLAP tags + Silero VAD + faster-whisper lyrics. Provides mood/genre/instrument tagging via CLAP zero-shot classification, speech detection via Silero VAD, and lyrics extraction via faster-whisper. All models run on CPU. Total: ~3-5 min per file. Usage: from caption_fast import caption_audio result = caption_audio("song.mp3") # {"caption": "Pop, Energetic, Guitar, Melodic, Upbeat", # "lyrics": "[Verse]\nSome lyrics here...", # "bpm": 120, "key": "C major", "signature": "4/4", # "tags": ["Pop", "Energetic", "Guitar", ...]} """ from __future__ import annotations import json import logging import os from pathlib import Path from typing import Dict, List, Optional logger = logging.getLogger(__name__) # Tag list for CLAP zero-shot classification (from clap-interrogator) TAGS = [ "Fast", "Slow", "Upbeat", "Downbeat", "Moderate", "Happy", "Sad", "Energetic", "Relaxed", "Melancholic", "Uplifting", "Aggressive", "Peaceful", "Romantic", "Dark", "Light", "Mysterious", "Dreamy", "Somber", "Hopeful", "Gloomy", "Cheerful", "Reflective", "Nostalgic", "Tense", "Calm", "Piano", "Guitar", "Violin", "Drums", "Bass", "Synthesizer", "Saxophone", "Trumpet", "Flute", "Cello", "Clarinet", "Harp", "Percussion", "Organ", "Accordion", "Electronic", "Acoustic", "Electric Guitar", "Acoustic Guitar", "Synth Pad", "Keyboards", "Rock", "Pop", "Jazz", "Classical", "Electronic", "Folk", "Hip-Hop", "Blues", "Ambient", "Country", "Reggae", "Funk", "Soul", "Metal", "Dance", "Disco", "House", "Techno", "Trance", "Soundtrack", "World", "Indie", "Alternative", "R&B", "EDM", "Chillwave", "Dubstep", "Lo-fi Hip-Hop", "Drum and Bass", "Jazz Fusion", "Neo-Soul", "Trap", "K-Pop", "J-Pop", "Reggaeton", "Punk", "Grunge", "Bright", "Warm", "Smooth", "Distorted", "Clean", "Lo-fi", "Layered", "Minimalist", "Cinematic", "Atmospheric", "Ethereal", "Groovy", "Rhythmic", "Melodic", "Harmonic", "Live", "Studio", "Instrumental", ] _clap_model = None _clap_processor = None _whisper_model = None _vad_model = None def _load_clap(): global _clap_model, _clap_processor if _clap_model is not None: return _clap_model, _clap_processor from transformers import ClapModel, ClapProcessor logger.info("[CLAP] Loading laion/larger_clap_music...") _clap_processor = ClapProcessor.from_pretrained("laion/larger_clap_music") _clap_model = ClapModel.from_pretrained("laion/larger_clap_music") _clap_model.eval() logger.info("[CLAP] Ready (~780MB)") return _clap_model, _clap_processor def _load_whisper(): global _whisper_model if _whisper_model is not None: return _whisper_model from faster_whisper import WhisperModel logger.info("[Whisper] Loading large-v3-turbo (int8, CPU)...") _whisper_model = WhisperModel( "large-v3-turbo", device="cpu", compute_type="int8", ) logger.info("[Whisper] Ready (~1.5GB)") return _whisper_model def _load_vad(): global _vad_model if _vad_model is not None: return _vad_model import torch logger.info("[VAD] Loading Silero VAD...") _vad_model, _vad_utils = torch.hub.load( repo_or_dir='snakers4/silero-vad', model='silero_vad', onnx=True, trust_repo=True, ) logger.info("[VAD] Ready (~2MB)") return _vad_model def unload_caption_models(): """Free all captioning models from memory.""" global _clap_model, _clap_processor, _whisper_model, _vad_model import gc _clap_model = None _clap_processor = None _whisper_model = None _vad_model = None gc.collect() logger.info("[Caption] All models unloaded") def tag_audio(audio_path: str, top_n: int = 10) -> List[str]: """Get top-N CLAP tags for an audio file.""" import librosa import 
def tag_audio(audio_path: str, top_n: int = 10) -> List[str]:
    """Get top-N CLAP tags for an audio file."""
    import librosa
    import torch

    model, processor = _load_clap()
    # CLAP expects 48 kHz mono audio.
    audio, sr = librosa.load(audio_path, sr=48000, mono=True)
    inputs = processor(
        text=TAGS,
        audios=[audio],  # ClapProcessor takes `audios=`, not `audio=`
        sampling_rate=48000,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_audio.softmax(dim=-1)
    _, top_indices = probs.topk(top_n, dim=1)
    return [TAGS[i] for i in top_indices[0].tolist()]


def detect_speech(audio_path: str, threshold: float = 5.0) -> bool:
    """Check if audio contains speech using Silero VAD.

    Returns True if speech is detected for more than `threshold` seconds.
    """
    import librosa
    import torch

    vad = _load_vad()
    vad.reset_states()  # the model is stateful; reset between files
    y, sr = librosa.load(audio_path, sr=16000, mono=True)
    wav = torch.from_numpy(y).unsqueeze(0)

    # Slide a 512-sample (32 ms) window over the waveform; Silero VAD at
    # 16 kHz expects exactly 512 samples per call.
    window_size = 512
    speech_windows = 0
    for i in range(0, wav.shape[1], window_size):
        chunk = wav[0, i:i + window_size]
        if len(chunk) < window_size:
            break
        prob = vad(chunk, 16000).item()
        if prob > 0.5:
            speech_windows += 1
    speech_duration = speech_windows * (window_size / 16000)
    logger.info("[VAD] Speech: %.1fs detected in %s",
                speech_duration, os.path.basename(audio_path))
    return speech_duration > threshold


def transcribe_lyrics(audio_path: str) -> str:
    """Extract lyrics from audio using faster-whisper."""
    model = _load_whisper()
    segments, info = model.transcribe(
        audio_path,
        language=None,    # autodetect
        beam_size=5,
        vad_filter=True,  # skip non-speech regions inside whisper too
    )
    lines = []
    for segment in segments:
        text = segment.text.strip()
        if text:
            lines.append(text)
    lyrics = "\n".join(lines)
    if not lyrics.strip():
        return "[Instrumental]"
    logger.info("[Whisper] Transcribed %d lines (lang=%s, prob=%.2f)",
                len(lines), info.language, info.language_probability)
    return lyrics


def get_bpm_key(audio_path: str) -> Dict[str, str]:
    """Estimate BPM and key via librosa.

    Key detection correlates the averaged chroma vector against the
    Krumhansl-Schmuckler major/minor key profiles at all 12 rotations.
    Time signature is not estimated and defaults to 4/4.
    """
    import librosa
    import numpy as np

    y, sr = librosa.load(audio_path, sr=None, mono=True)
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
    # Newer librosa returns tempo as an array; older versions return a float.
    bpm = int(round(float(tempo.item() if hasattr(tempo, 'item') else tempo)))

    chroma = librosa.feature.chroma_cens(y=y, sr=sr)
    chroma_avg = np.mean(chroma, axis=1)
    keys = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
    # Krumhansl-Schmuckler key profiles
    major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
                              2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
    minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53,
                              2.54, 4.75, 3.98, 2.69, 3.34, 3.17])
    best_corr = -1.0
    best_key = "C major"
    for i in range(12):
        maj_corr = float(np.corrcoef(np.roll(major_profile, i), chroma_avg)[0, 1])
        min_corr = float(np.corrcoef(np.roll(minor_profile, i), chroma_avg)[0, 1])
        if maj_corr > best_corr:
            best_corr = maj_corr
            best_key = f"{keys[i]} major"
        if min_corr > best_corr:
            best_corr = min_corr
            best_key = f"{keys[i]} minor"
    return {"bpm": str(bpm), "key": best_key, "signature": "4/4"}
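
# Quick sanity check for the tempo estimator (an illustrative sketch, not
# part of the original pipeline): synthesize a click track at a known BPM
# with librosa.clicks, write it via soundfile (a librosa dependency --
# assumed installed), and confirm beat_track recovers a nearby value.
def _selftest_bpm(target_bpm: int = 120, duration_s: int = 10) -> None:
    import librosa
    import numpy as np
    import soundfile as sf

    sr = 22050
    times = np.arange(0, duration_s, 60.0 / target_bpm)
    clicks = librosa.clicks(times=times, sr=sr, length=sr * duration_s)
    sf.write("_click_test.wav", clicks, sr)
    print(get_bpm_key("_click_test.wav"))  # expect bpm close to target_bpm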
def caption_audio(
    audio_path: str,
    top_n: int = 10,
    extract_lyrics: bool = True,
    speech_threshold: float = 5.0,
) -> Dict[str, str | List[str]]:
    """Full fast captioning pipeline for one audio file.

    Returns a dict with: caption, lyrics, bpm, key, signature, tags.
    """
    fname = os.path.basename(audio_path)
    logger.info("[Caption] Processing %s...", fname)

    # 1. CLAP tags (mood, genre, instruments)
    tags = tag_audio(audio_path, top_n=top_n)
    caption = ", ".join(tags)
    logger.info("[Caption] %s: tags=%s", fname, caption)

    # 2. BPM + key via librosa
    bpm_key = get_bpm_key(audio_path)
    logger.info("[Caption] %s: BPM=%s, key=%s",
                fname, bpm_key["bpm"], bpm_key["key"])

    # 3. Speech detection + lyrics
    lyrics = "[Instrumental]"
    if extract_lyrics:
        has_speech = detect_speech(audio_path, threshold=speech_threshold)
        if has_speech:
            logger.info("[Caption] %s: speech detected, transcribing lyrics...", fname)
            lyrics = transcribe_lyrics(audio_path)
        else:
            logger.info("[Caption] %s: no speech, marking instrumental", fname)

    return {
        "caption": caption,
        "lyrics": lyrics,
        "bpm": bpm_key["bpm"],
        "key": bpm_key["key"],
        "signature": bpm_key["signature"],
        "tags": tags,
    }
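
# Minimal CLI sketch (an addition, not part of the original pipeline):
# caption each file passed on the command line and print the result as
# JSON. Assumes the module is run directly; adjust logging to taste.
if __name__ == "__main__":
    import argparse
    import json

    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(description="Fast audio captioning")
    parser.add_argument("files", nargs="+", help="audio files to caption")
    parser.add_argument("--top-n", type=int, default=10)
    parser.add_argument("--no-lyrics", action="store_true",
                        help="skip VAD + whisper, mark everything instrumental")
    args = parser.parse_args()

    for path in args.files:
        result = caption_audio(path, top_n=args.top_n,
                               extract_lyrics=not args.no_lyrics)
        print(json.dumps(result, indent=2, ensure_ascii=False))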