# NOTE(review): the three lines that were here ("Spaces:" / "Running" /
# "Running") are HF Spaces page scrape residue, not Python code.
| """ACE-Step 1.5 XL (CPU) - Gradio frontend + CLI for ace-server GGUF inference""" | |
| import os | |
| import sys | |
| import time | |
| import json | |
| import argparse | |
| import base64 | |
| import tempfile | |
| import subprocess | |
| import shutil | |
| import string | |
| import random | |
| import requests | |
| import logging | |
| import threading | |
| from train_engine import ( | |
| preprocess_audio, | |
| train_lora_generator, | |
| cancel_training, | |
| _training_cancel, | |
| get_trained_loras as _get_trained_loras_engine, | |
| MAX_TRAINING_TIME, | |
| ) | |
| logging.basicConfig(level=logging.INFO, format="%(message)s", stream=sys.stdout) | |
| logger = logging.getLogger(__name__) | |
# ---------------------------------------------------------------------------
# Configurable limits (edit here, not buried in code)
# ---------------------------------------------------------------------------
MAX_TOTAL_AUDIO = 1800  # seconds total across all uploaded files (30 min)
# MAX_TRAINING_TIME is imported from train_engine (single source of truth)
MAX_AUDIO_FILES = 50  # max number of training audio files per run

# ---------------------------------------------------------------------------
# Paths & constants
# ---------------------------------------------------------------------------
ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085")
OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Clean up old inference temp files (older than 1 hour) at startup
_CLEANUP_MAX_AGE = 3600  # seconds


def _cleanup_stale_outputs(max_age=_CLEANUP_MAX_AGE):
    """Best-effort removal of .wav/.mp3 files in OUTPUT_DIR older than max_age seconds."""
    now = time.time()
    for name in os.listdir(OUTPUT_DIR):
        if not name.lower().endswith((".wav", ".mp3")):
            continue
        path = os.path.join(OUTPUT_DIR, name)
        try:
            if os.path.isfile(path) and (now - os.path.getmtime(path)) > max_age:
                os.remove(path)
        except OSError:
            pass  # file vanished or is locked -- best effort only


try:
    _cleanup_stale_outputs()
except Exception:
    pass  # startup cleanup must never prevent the app from booting

ACE_CHECKPOINT_DIR = os.environ.get("ACE_CHECKPOINT_DIR", "/app/checkpoints")
ACE_SOURCE_DIR = "/app/ace-step-source"
ACE_HF_MODEL = "ACE-Step/Ace-Step1.5"
ADAPTER_DIR = os.environ.get("ACE_ADAPTER_DIR", "/app/adapters")
MODELS_DIR = os.environ.get("ACE_MODELS_DIR", "/app/models")
ACE_SERVER_BIN = "/app/ace-server"

# Detect if running on HF Space (ace-server available) vs locally (PyTorch only)
_is_space = os.path.isfile(ACE_SERVER_BIN) or os.environ.get("SPACE_ID") is not None
_training_lock = threading.Lock()

# HF repo for on-demand GGUF downloads
GGUF_HF_REPO = "Serveurperso/ACE-Step-1.5-GGUF"
| # --------------------------------------------------------------------------- | |
| # ace-server helpers | |
| # --------------------------------------------------------------------------- | |
def _server_ok():
    """Return True when the ace-server /health endpoint answers HTTP 200."""
    try:
        resp = requests.get(f"{ACE_SERVER}/health", timeout=5)
    except Exception:
        return False
    return resp.status_code == 200
def _get_props():
    """Fetch server properties (models, adapters); empty dict on any failure."""
    try:
        resp = requests.get(f"{ACE_SERVER}/props", timeout=10)
        if resp.status_code == 200:
            return resp.json()
    except Exception:
        pass  # unreachable server or bad JSON -- treat as "no props"
    return {}
def _poll_job(job_id, timeout=600, progress_cb=None, cancel_check=None):
    """Poll a job until done/error/timeout/cancelled.

    Returns (status, elapsed_seconds, data); data is the last /job payload for
    "done"/"error" and None for "cancelled"/"timeout".
    """
    started = time.time()
    while time.time() - started < timeout:
        if cancel_check is not None and cancel_check():
            return "cancelled", time.time() - started, None
        try:
            resp = requests.get(f"{ACE_SERVER}/job", params={"id": job_id}, timeout=5)
            payload = resp.json()
            state = payload.get("status", "unknown")
            if progress_cb is not None:
                progress_cb(state, payload)
            if state in ("done", "error"):
                return state, time.time() - started, payload
        except Exception:
            pass  # transient network/JSON hiccup: keep polling
        time.sleep(1)
    return "timeout", time.time() - started, None
def _fetch_result(job_id, timeout=60):
    """Fetch result bytes/json for a completed job; returns the raw response."""
    query = {"id": job_id, "result": 1}
    return requests.get(f"{ACE_SERVER}/job", params=query, timeout=timeout)
| def _caption_via_understand(audio_path, timeout=600, cancel_check=None): | |
| """Call ace-server /understand for a rich caption. Returns dict or None.""" | |
| fname = os.path.basename(audio_path) | |
| try: | |
| with open(audio_path, "rb") as f: | |
| r = requests.post( | |
| f"{ACE_SERVER}/understand", | |
| files={"audio": (fname, f, "audio/mpeg")}, | |
| timeout=30, | |
| ) | |
| if r.status_code != 200: | |
| logger.warning("[Caption] %s: /understand %d: %s", fname, r.status_code, r.text[:200]) | |
| return None | |
| job_id = r.json().get("id") | |
| if not job_id: | |
| return None | |
| except Exception as exc: | |
| logger.warning("[Caption] %s: /understand submit failed: %s", fname, exc) | |
| return None | |
| status, elapsed, poll_data = _poll_job(job_id, timeout=timeout, cancel_check=cancel_check) | |
| if status != "done": | |
| logger.warning("[Caption] %s: /understand -> %s (%.0fs)", fname, status, elapsed) | |
| return None | |
| # Fetch result — /understand returns multipart/mixed (JSON + latents) | |
| try: | |
| r = _fetch_result(job_id, timeout=120) | |
| if r.status_code != 200: | |
| logger.warning("[Caption] %s: result fetch HTTP %d", fname, r.status_code) | |
| return None | |
| content_type = r.headers.get("Content-Type", "") | |
| # multipart/mixed: extract JSON part (caption metadata) | |
| if "multipart" in content_type: | |
| boundary = None | |
| for part in content_type.split(";"): | |
| part = part.strip() | |
| if part.startswith("boundary="): | |
| boundary = part.split("=", 1)[1].strip('"') | |
| if boundary: | |
| import re | |
| parts = r.content.split(f"--{boundary}".encode()) | |
| for part in parts: | |
| if b"application/json" in part: | |
| json_start = part.find(b"{") | |
| json_end = part.rfind(b"}") + 1 | |
| if json_start >= 0 and json_end > json_start: | |
| data = json.loads(part[json_start:json_end]) | |
| if isinstance(data, dict) and data.get("caption"): | |
| logger.info("[Caption] %s: got caption (%d chars)", | |
| fname, len(data["caption"])) | |
| return data | |
| # Plain JSON fallback | |
| if r.text.strip(): | |
| data = r.json() | |
| if isinstance(data, dict) and data.get("caption"): | |
| return data | |
| except Exception as exc: | |
| logger.warning("[Caption] %s: result parse failed: %s", fname, exc) | |
| logger.warning("[Caption] %s: no caption extracted from result", fname) | |
| return None | |
def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
                  adapter=None, lm_model=None, progress_cb=None,
                  synth_model="acestep-v15-turbo-Q4_K_M.gguf"):
    """Run the full LM -> synth pipeline against ace-server.

    Parameters:
        caption: music description; empty falls back to a default prompt.
        lyrics: lyrics text; blank/None becomes "[Instrumental]".
        bpm, duration, seed, steps: numeric controls; non-numeric values are
            silently omitted from the request. Duration is capped at 300s.
        output_format: "wav" or "mp3" (anything else coerces to "mp3").
        adapter: optional LoRA adapter name, forwarded to both LM and synth.
        lm_model: optional LM GGUF filename.
        progress_cb: optional callable(phase, data) invoked at each stage.
        synth_model: synthesis GGUF filename (previously hard-coded; default
            preserves the old behavior).

    Returns (audio_path, status_msg). Raises RuntimeError on any server
    failure or job timeout.
    """
    t0 = time.time()
    # -- Build LM request --
    req = {"caption": caption or "upbeat electronic dance music"}
    req["lyrics"] = lyrics if lyrics and lyrics.strip() else "[Instrumental]"
    try:
        if bpm and int(float(bpm)) > 0:
            req["bpm"] = int(float(bpm))
        if duration and float(duration) > 0:
            req["duration"] = min(float(duration), 300)  # server-side cap: 5 min
        if seed is not None and int(float(seed)) >= 0:
            req["seed"] = int(float(seed))
        if steps and int(float(steps)) > 0:
            req["inference_steps"] = int(float(steps))
    except (ValueError, TypeError):
        pass  # non-numeric UI input: just leave the optional fields out
    if adapter:
        req["adapter"] = adapter
    if lm_model:
        req["model"] = lm_model
    fmt = output_format if output_format in ("wav", "mp3") else "mp3"
    synth_fmt = "wav16" if fmt == "wav" else "mp3"
    suffix = f".{fmt}"
    # -- LM phase --
    if progress_cb:
        progress_cb("lm_submit", None)
    r = requests.post(f"{ACE_SERVER}/lm", json=req, timeout=30)
    if r.status_code != 200:
        raise RuntimeError(f"LM submit failed: {r.status_code} {r.text}")
    lm_job_id = r.json().get("id")
    if progress_cb:
        progress_cb("lm_poll", {"job_id": lm_job_id})
    lm_status, lm_elapsed, _ = _poll_job(lm_job_id, timeout=900)
    if lm_status != "done":
        raise RuntimeError(f"LM {lm_status} after {lm_elapsed:.0f}s")
    # Fetch LM result: a list of ready-to-submit synth requests
    r = _fetch_result(lm_job_id)
    lm_results = r.json()
    if not isinstance(lm_results, list) or len(lm_results) == 0:
        raise RuntimeError(f"LM returned no results: {lm_results}")
    synth_request = lm_results[0]
    # -- Synth phase (mutates the LM-produced request in place) --
    synth_request["output_format"] = synth_fmt
    if adapter:
        synth_request["adapter"] = adapter
    synth_request["synth_model"] = synth_model
    if progress_cb:
        progress_cb("synth_submit", None)
    r = requests.post(f"{ACE_SERVER}/synth", json=synth_request, timeout=30)
    if r.status_code != 200:
        raise RuntimeError(f"Synth submit failed: {r.status_code} {r.text}")
    synth_job_id = r.json().get("id")
    if progress_cb:
        progress_cb("synth_poll", {"job_id": synth_job_id})
    synth_status, synth_elapsed, _ = _poll_job(synth_job_id, timeout=600)
    if synth_status != "done":
        raise RuntimeError(f"Synth {synth_status} after {synth_elapsed:.0f}s")
    # -- Fetch audio bytes into a persistent temp file under OUTPUT_DIR --
    if progress_cb:
        progress_cb("fetch", None)
    r = _fetch_result(synth_job_id, timeout=60)
    if r.status_code != 200:
        raise RuntimeError(f"Audio fetch failed: {r.status_code}")
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, dir=OUTPUT_DIR, delete=False)
    tmp.write(r.content)
    tmp.close()
    elapsed = time.time() - t0
    msg = f"Done in {elapsed:.0f}s | {duration}s audio, {steps} steps, {fmt}"
    return tmp.name, msg
| # --------------------------------------------------------------------------- | |
| # LM model scanning & on-demand download | |
| # --------------------------------------------------------------------------- | |
| DEFAULT_LM = "acestep-5Hz-lm-1.7B-Q8_0.gguf" | |
| AVAILABLE_LM_MODELS = [ | |
| "acestep-5Hz-lm-1.7B-Q8_0.gguf", | |
| "acestep-5Hz-lm-0.6B-Q8_0.gguf", | |
| "acestep-5Hz-lm-4B-Q5_K_M.gguf", | |
| ] | |
def _scan_lm_models():
    """Return LM model choices. Installed shown as-is, others need download."""
    installed = set()
    if os.path.isdir(MODELS_DIR):
        installed = {
            name
            for name in os.listdir(MODELS_DIR)
            if "-lm-" in name and name.endswith(".gguf")
        }
    return [
        model if model in installed else f"{model} [not installed]"
        for model in AVAILABLE_LM_MODELS
    ]
def _download_lm_model(filename):
    """Download a GGUF LM model from HF if not already present.

    Returns the local file path on success, None on failure.
    """
    dest = os.path.join(MODELS_DIR, filename)
    if os.path.isfile(dest):
        return dest  # already installed, nothing to do
    try:
        from huggingface_hub import hf_hub_download
        return hf_hub_download(
            repo_id=GGUF_HF_REPO,
            filename=filename,
            local_dir=MODELS_DIR,
        )
    except Exception as exc:
        logger.error("Failed to download %s: %s", filename, exc)
        return None
| # --------------------------------------------------------------------------- | |
| # LoRA listing for UI dropdowns | |
| # --------------------------------------------------------------------------- | |
def _list_lora_choices():
    """Return list of LoRA choices for dropdown, including 'None'."""
    names = []
    if os.path.isdir(ADAPTER_DIR):
        names = [
            entry
            for entry in os.listdir(ADAPTER_DIR)
            if os.path.isdir(os.path.join(ADAPTER_DIR, entry))
        ]
    return ["None (no LoRA)"] + names
| # --------------------------------------------------------------------------- | |
| # ace-server stop/start helpers | |
| # --------------------------------------------------------------------------- | |
| _ace_proc = None | |
def _stop_ace_server():
    """Stop ace-server: terminate the tracked child process, else fall back to pkill."""
    global _ace_proc
    logger.info("[ace-server] Stopping...")
    if _ace_proc is not None and _ace_proc.poll() is None:
        # We own the process handle: graceful terminate, then force kill.
        _ace_proc.terminate()
        try:
            _ace_proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            _ace_proc.kill()  # refused to exit within 10s
        _ace_proc = None
        logger.info("[ace-server] Stopped (tracked PID)")
    else:
        # No tracked handle (e.g. started outside this module): best-effort pkill.
        try:
            subprocess.run(
                ["pkill", "ace-server"],
                stderr=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                timeout=10,
            )
            logger.info("[ace-server] Stopped (pkill)")
        except Exception:
            pass  # pkill unavailable or nothing to kill
    time.sleep(1)  # brief pause so the port/sockets can be released
def _start_ace_server(max_retries: int = 3, retry_delay: float = 5.0):
    """Start ace-server in background and wait for health.

    Retries up to max_retries times with retry_delay seconds between attempts.
    Returns True once the server reports healthy, False otherwise.
    """
    global _ace_proc
    cmd = [
        ACE_SERVER_BIN, "--host", "127.0.0.1", "--port", "8085",
        "--models", MODELS_DIR, "--adapters", ADAPTER_DIR, "--max-batch", "1",
    ]
    for attempt in range(1, max_retries + 1):
        logger.info(
            "[ace-server] Starting (attempt %d/%d) with --adapters %s",
            attempt, max_retries, ADAPTER_DIR,
        )
        try:
            _ace_proc = subprocess.Popen(cmd)
        except Exception as exc:
            logger.error("[ace-server] Failed to start: %s", exc)
            if attempt == max_retries:
                return False
            time.sleep(retry_delay)
            continue
        # Probe /health up to 30 times, 2s apart (~60s budget).
        healthy = False
        for _ in range(30):
            if _server_ok():
                healthy = True
                break
            time.sleep(2)
        if healthy:
            logger.info("[ace-server] Healthy")
            return True
        logger.warning("[ace-server] Health check timeout on attempt %d/%d", attempt, max_retries)
        # Kill the unhealthy process before retrying.
        if _ace_proc is not None and _ace_proc.poll() is None:
            _ace_proc.kill()
            try:
                _ace_proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                pass
        if attempt < max_retries:
            time.sleep(retry_delay)
    logger.error("[ace-server] Failed to start after %d attempts", max_retries)
    return False
| # --------------------------------------------------------------------------- | |
| # CLI mode | |
| # --------------------------------------------------------------------------- | |
def cli_main():
    """Parse CLI arguments, run the LM -> synth pipeline, report the output path."""
    parser = argparse.ArgumentParser(
        description="ACE-Step 1.5 XL (CPU) - CLI inference via ace-server",
    )
    parser.add_argument("caption", nargs="?", default="upbeat electronic dance music",
                        help="Music description / caption")
    parser.add_argument("--lyrics", "-l", default="[Instrumental]",
                        help="Lyrics text (use '[Instrumental]' for no vocals)")
    parser.add_argument("--bpm", type=int, default=120, help="Beats per minute")
    parser.add_argument("--duration", "-d", type=float, default=10,
                        help="Duration in seconds (max 300)")
    parser.add_argument("--steps", "-s", type=int, default=8,
                        help="Inference steps (1-32)")
    parser.add_argument("--seed", type=int, default=-1,
                        help="Random seed (-1 for random)")
    parser.add_argument("--format", "-f", choices=["wav", "mp3"], default="wav",
                        help="Output audio format")
    parser.add_argument("--adapter", "-a", default=None,
                        help="LoRA adapter name")
    parser.add_argument("-o", "--output", default=None,
                        help="Output file path (default: auto in outputs dir)")
    parser.add_argument("--server", default=None,
                        help="ace-server URL (default: http://127.0.0.1:8085)")
    args = parser.parse_args()

    if args.server:
        # Point the whole module at the user-supplied server URL.
        global ACE_SERVER
        ACE_SERVER = args.server
    if not _server_ok():
        print(f"ERROR: ace-server not reachable at {ACE_SERVER}", file=sys.stderr)
        sys.exit(1)

    seed = None if args.seed < 0 else args.seed

    def cli_progress(phase, data):
        # Human-readable progress line per pipeline phase.
        if phase == "lm_poll":
            msg = f"LM generating (job {data['job_id']})..." if data else "LM generating..."
        elif phase == "synth_poll":
            msg = f"Synthesizing (job {data['job_id']})..." if data else "Synthesizing..."
        else:
            msg = {
                "lm_submit": "Submitting LM job...",
                "synth_submit": "Submitting synth job...",
                "fetch": "Fetching audio...",
            }.get(phase, phase)
        print(f" [{phase}] {msg}")

    print(f"ACE-Step CLI | caption: {args.caption}")
    print(f" lyrics: {args.lyrics} | bpm: {args.bpm} | duration: {args.duration}s "
          f"| steps: {args.steps} | seed: {args.seed} | format: {args.format}")
    try:
        audio_path, status = _run_pipeline(
            caption=args.caption,
            lyrics=args.lyrics,
            bpm=args.bpm,
            duration=args.duration,
            seed=seed,
            steps=args.steps,
            output_format=args.format,
            adapter=args.adapter,
            progress_cb=cli_progress,
        )
    except RuntimeError as e:
        print(f"ERROR: {e}", file=sys.stderr)
        sys.exit(1)

    if args.output:
        # Relocate the generated temp file to the user-requested path.
        target_dir = os.path.dirname(os.path.abspath(args.output))
        os.makedirs(target_dir, exist_ok=True)
        shutil.move(audio_path, args.output)
        audio_path = args.output
    print(f" {status}")
    print(f" Output: {audio_path}")
| # --------------------------------------------------------------------------- | |
| # Gradio UI mode | |
| # --------------------------------------------------------------------------- | |
| def gradio_main(): | |
| import gradio as gr | |
| import gc | |
| # -- Persistent training log buffer (survives across yields) -- | |
| _train_log_lines = [] | |
| # -- Generate tab handler -- | |
| def generate_music(caption, lyrics, bpm, duration, seed, | |
| steps, lora_select, lm_model_select, | |
| progress=gr.Progress(track_tqdm=True)): | |
| if not _training_lock.acquire(blocking=False): | |
| return None, "Training in progress. Inference unavailable until training completes. Press Cancel to stop training." | |
| _training_lock.release() | |
| if not _server_ok(): | |
| return None, "ace-server not running. Check logs." | |
| if not lyrics or lyrics.strip() == "": | |
| lyrics = "[Instrumental]" | |
| actual_seed = None if seed is None or int(seed) < 0 else int(seed) | |
| adapter = None if lora_select == "None (no LoRA)" else lora_select | |
| lm_model_file = lm_model_select.replace(" [not installed]", "") if lm_model_select else None | |
| if lm_model_file and "[not installed]" in (lm_model_select or ""): | |
| _download_lm_model(lm_model_file) | |
| lm_model = lm_model_file | |
| progress_map = { | |
| "lm_submit": (0.05, "Submitting LM job..."), | |
| "lm_poll": (0.10, "LM generating..."), | |
| "synth_submit": (0.40, "Submitting synth job..."), | |
| "synth_poll": (0.50, "Synthesizing audio..."), | |
| "fetch": (0.90, "Fetching audio..."), | |
| } | |
| def gr_progress(phase, data): | |
| pct, desc = progress_map.get(phase, (0.5, phase)) | |
| if data and "job_id" in data: | |
| desc += f" (job {data['job_id']})" | |
| progress(pct, desc=desc) | |
| try: | |
| audio_path, status = _run_pipeline( | |
| caption=caption, | |
| lyrics=lyrics, | |
| bpm=bpm, | |
| duration=duration, | |
| seed=actual_seed, | |
| steps=steps, | |
| output_format="mp3", | |
| adapter=adapter, | |
| lm_model=lm_model, | |
| progress_cb=gr_progress, | |
| ) | |
| return audio_path, status | |
| except RuntimeError as e: | |
| return None, str(e) | |
| except Exception as e: | |
| return None, f"Unexpected error: {e}" | |
| # -- Server info helper -- | |
| def get_server_status(): | |
| if not _server_ok(): | |
| return "ace-server: OFFLINE" | |
| props = _get_props() | |
| lines = ["ace-server: ONLINE"] | |
| if props: | |
| lines.append(json.dumps(props, indent=2)) | |
| return "\n".join(lines) | |
| # -- Training generator (direct integration, no subprocess) -- | |
| def train_lora_ui(audio_files, lora_name, epochs, lr, rank, use_lm_caption): | |
| """Generator that yields (train_log, train_btn_update, cancel_btn_update).""" | |
| import gc as _gc | |
| _train_log_lines.clear() | |
| train_start = time.time() | |
| def _log(msg): | |
| elapsed = int(time.time() - train_start) | |
| m, s = divmod(elapsed, 60) | |
| h, m = divmod(m, 60) | |
| ts = f"+{h}:{m:02d}:{s:02d}" if h else f"+{m:02d}:{s:02d}" | |
| line = f"[{ts}] {msg}" | |
| _train_log_lines.append(line) | |
| logger.info(msg) | |
| if len(_train_log_lines) > 2000: | |
| _train_log_lines[:] = _train_log_lines[-1000:] | |
| def _log_text(): | |
| return "\n".join(_train_log_lines) | |
| # -- Validation -- | |
| if not audio_files: | |
| _log("[FAIL] No audio files uploaded.") | |
| yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File() | |
| return | |
| if len(audio_files) > MAX_AUDIO_FILES: | |
| _log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}") | |
| yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File() | |
| return | |
| lora_name = (lora_name or "").strip() or "my-lora" | |
| lora_name = "".join(c if c.isalnum() or c in "-_" else "-" for c in lora_name) | |
| epochs = max(1, min(int(epochs), 1000)) | |
| lr = float(lr) | |
| rank = max(1, min(int(rank), 128)) | |
| work_dir = os.path.join(OUTPUT_DIR, "train_workspace", lora_name) | |
| os.makedirs(work_dir, exist_ok=True) | |
| audio_dir = os.path.join(work_dir, "audio_input") | |
| if os.path.exists(audio_dir): | |
| shutil.rmtree(audio_dir) | |
| os.makedirs(audio_dir) | |
| adapter_out = os.path.join(ADAPTER_DIR, lora_name) | |
| os.makedirs(adapter_out, exist_ok=True) | |
| # Copy uploaded audio files + check total duration | |
| _log(f"[INFO] Preparing {len(audio_files)} audio files...") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| import librosa as _lr | |
| total_dur = 0.0 | |
| accepted = 0 | |
| skipped_names = [] | |
| truncated_names = [] | |
| for f in audio_files: | |
| src = f.name if hasattr(f, "name") else str(f) | |
| fname = os.path.basename(src) | |
| # .txt/.json sidecars: copy as caption files, skip duration check | |
| if fname.lower().endswith((".txt", ".json")): | |
| shutil.copy2(src, os.path.join(audio_dir, fname)) | |
| continue | |
| try: | |
| dur = _lr.get_duration(path=src) | |
| except Exception: | |
| dur = 0.0 | |
| if dur <= 0: | |
| skipped_names.append(f"{fname} (invalid/empty)") | |
| continue | |
| remaining = MAX_TOTAL_AUDIO - total_dur | |
| if remaining <= 0: | |
| skipped_names.append(fname) | |
| continue | |
| if dur > remaining: | |
| # Truncate this file to fit | |
| import soundfile as _sf | |
| y, sr = _lr.load(src, sr=None, mono=False) | |
| max_samples = int(remaining * sr) | |
| if y.ndim == 1: | |
| y = y[:max_samples] | |
| else: | |
| y = y[:, :max_samples] | |
| dst = os.path.join(audio_dir, fname) | |
| _sf.write(dst, y.T if y.ndim > 1 else y, sr) | |
| truncated_names.append(f"{fname} ({dur:.0f}s -> {remaining:.0f}s)") | |
| total_dur += remaining | |
| accepted += 1 | |
| else: | |
| shutil.copy2(src, os.path.join(audio_dir, fname)) | |
| total_dur += dur | |
| accepted += 1 | |
| if truncated_names: | |
| _log(f"[WARN] Truncated: {', '.join(truncated_names)}") | |
| if skipped_names: | |
| _log(f"[WARN] Skipped (over {MAX_TOTAL_AUDIO/60:.0f} min cap): {', '.join(skipped_names)}") | |
| _log(f"[INFO] Total audio: {total_dur:.0f}s ({total_dur/60:.1f} min), {accepted} files") | |
| _log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | " | |
| f"Epochs: {epochs} | LR: {lr} | Rank: {rank}") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| # Caption audio files without user-provided sidecars | |
| audio_to_caption = [] | |
| for audio_fname in sorted(os.listdir(audio_dir)): | |
| full_path = os.path.join(audio_dir, audio_fname) | |
| if not os.path.isfile(full_path): | |
| continue | |
| ext = audio_fname.lower().rsplit(".", 1)[-1] if "." in audio_fname else "" | |
| if ext in ("json", "txt"): | |
| continue | |
| stem = audio_fname.rsplit(".", 1)[0] if "." in audio_fname else audio_fname | |
| sidecar_json = os.path.join(audio_dir, stem + ".json") | |
| sidecar_txt = os.path.join(audio_dir, stem + ".txt") | |
| if os.path.isfile(sidecar_json) or os.path.isfile(sidecar_txt): | |
| _log(f" {audio_fname}: using caption file") | |
| continue | |
| audio_to_caption.append((audio_fname, full_path, sidecar_json)) | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| if audio_to_caption and use_lm_caption and _server_ok(): | |
| # --- Mode: GGUF LM captioning (best quality, 5h timeout per file) --- | |
| LM_TIMEOUT = 18000 # 5h per file | |
| est_total = int(total_dur * 7 + len(audio_to_caption) * 600) | |
| if est_total > LM_TIMEOUT: | |
| _log(f"[WARN] Estimated {est_total // 60} min exceeds 5h, switching to fast captioning") | |
| use_lm_caption = False | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| else: | |
| _log(f"[INFO] LM captioning {len(audio_to_caption)} files (5h timeout per file)...") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| for audio_fname, full_path, sidecar_json in audio_to_caption: | |
| if _training_cancel.is_set(): | |
| break | |
| _log(f" {audio_fname}: LM captioning...") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| caption_data = _caption_via_understand( | |
| full_path, timeout=LM_TIMEOUT, | |
| cancel_check=lambda: _training_cancel.is_set(), | |
| ) | |
| if caption_data: | |
| bpm_s = caption_data.get("bpm", "?") | |
| key_s = caption_data.get("keyscale", caption_data.get("key", "?")) | |
| _log(f" {audio_fname}: OK (BPM={bpm_s}, key={key_s})") | |
| with open(sidecar_json, "w") as cj: | |
| json.dump(caption_data, cj) | |
| else: | |
| _log(f" {audio_fname}: LM failed") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| if audio_to_caption and not use_lm_caption: | |
| # --- Mode: Fast captioning (CLAP + Whisper + librosa) --- | |
| _log(f"[INFO] Fast captioning {len(audio_to_caption)} files " | |
| f"(CLAP tags + lyrics + BPM)...") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| try: | |
| from caption_fast import caption_audio, unload_caption_models | |
| for audio_fname, full_path, sidecar_json in audio_to_caption: | |
| if _training_cancel.is_set(): | |
| break | |
| _log(f" {audio_fname}: analyzing...") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| try: | |
| result = caption_audio(full_path) | |
| _log(f" {audio_fname}: {result.get('caption', '')[:60]}") | |
| if result.get("lyrics") and result["lyrics"] != "[Instrumental]": | |
| _log(f" {audio_fname}: lyrics extracted ({len(result['lyrics'])} chars)") | |
| with open(sidecar_json, "w") as cj: | |
| json.dump(result, cj) | |
| except Exception as cap_exc: | |
| _log(f" {audio_fname}: fast caption failed: {cap_exc}") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| unload_caption_models() | |
| _gc.collect() | |
| except ImportError: | |
| _log("[WARN] Fast captioning not available, using librosa fallback") | |
| for audio_fname, full_path, sidecar_json in audio_to_caption: | |
| try: | |
| y_cap, sr_cap = _lr.load(full_path, sr=None, mono=True) | |
| tempo_arr, _ = _lr.beat.beat_track(y=y_cap, sr=sr_cap) | |
| bpm_val = int(round(float( | |
| tempo_arr.item() if hasattr(tempo_arr, 'item') else tempo_arr))) | |
| fallback = {"caption": audio_fname.rsplit(".", 1)[0], | |
| "bpm": str(bpm_val), "key": "", "signature": "4/4", | |
| "lyrics": "[Instrumental]"} | |
| with open(sidecar_json, "w") as cj: | |
| json.dump(fallback, cj) | |
| _log(f" {audio_fname}: librosa BPM={bpm_val}") | |
| except Exception as exc: | |
| _log(f" {audio_fname}: failed: {exc}") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| if _training_cancel.is_set(): | |
| _training_cancel.clear() | |
| _log("[CANCELLED] Stopped") | |
| yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File() | |
| shutil.rmtree(work_dir, ignore_errors=True) | |
| return | |
| # Stop ace-server before training (frees memory) | |
| _training_lock.acquire() | |
| _log("[INFO] Stopping ace-server for training...") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| _stop_ace_server() | |
| _gc.collect() | |
| _cleanup_done = False | |
| try: | |
| # -- Phase 1: Preprocessing (runs in thread for live progress) -- | |
| preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors") | |
| _preprocess_log_len = len(_train_log_lines) | |
| def preprocess_progress(current, total, desc): | |
| _log(f" {desc} ({current}/{total})") | |
| _preprocess_result = [None] | |
| _preprocess_error = [None] | |
| def _run_preprocess(): | |
| try: | |
| _preprocess_result[0] = preprocess_audio( | |
| audio_dir=audio_dir, | |
| output_dir=preprocessed_dir, | |
| checkpoint_dir=ACE_CHECKPOINT_DIR, | |
| device="cpu", | |
| variant="turbo", | |
| max_duration=float(MAX_TOTAL_AUDIO), | |
| progress_callback=preprocess_progress, | |
| cancel_check=lambda: _training_cancel.is_set(), | |
| ) | |
| except Exception as exc: | |
| _preprocess_error[0] = exc | |
| _log("[Step 1/2] Encoding audio → training data (VAE + text encoder)...") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| t = threading.Thread(target=_run_preprocess, daemon=True) | |
| t.start() | |
| while t.is_alive(): | |
| t.join(timeout=3) | |
| if _training_cancel.is_set(): | |
| _training_cancel.clear() | |
| _log("[CANCELLED] Stopped during preprocessing") | |
| yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File() | |
| return | |
| if len(_train_log_lines) > _preprocess_log_len: | |
| _preprocess_log_len = len(_train_log_lines) | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| if _preprocess_error[0]: | |
| raise _preprocess_error[0] | |
| result = _preprocess_result[0] | |
| processed = result.get("processed", 0) | |
| failed = result.get("failed", 0) | |
| total = result.get("total", 0) | |
| _log(f"[OK] Preprocessed: {processed}/{total} files (failed: {failed})") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| if processed == 0: | |
| _log("[FAIL] No files preprocessed successfully. Cannot train.") | |
| yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File() | |
| return | |
| _gc.collect() | |
| # -- Phase 2: Training (random 60s crops for speed + augmentation) -- | |
| _log("[Step 2/2] Training LoRA...") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| for msg in train_lora_generator( | |
| dataset_dir=preprocessed_dir, | |
| output_dir=adapter_out, | |
| checkpoint_dir=ACE_CHECKPOINT_DIR, | |
| epochs=epochs, | |
| lr=lr, | |
| rank=rank, | |
| alpha=rank * 2, | |
| dropout=0.0, | |
| batch_size=1, | |
| gradient_accumulation_steps=4, | |
| warmup_steps=100, | |
| weight_decay=0.01, | |
| max_grad_norm=1.0, | |
| save_every_n_epochs=0, | |
| seed=42, | |
| variant="turbo", | |
| device="cpu", | |
| chunk_duration=60, | |
| log_every=5, | |
| ): | |
| elapsed = time.time() - train_start | |
| if elapsed > MAX_TRAINING_TIME: | |
| _log(f"[WARN] Training timed out after {int(elapsed)}s") | |
| cancel_training() | |
| break | |
| _log(msg) | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| if msg.strip() == "[DONE]": | |
| break | |
| _log(f"[INFO] Total time: {time.time() - train_start:.0f}s") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| except GeneratorExit: | |
| _training_cancel.set() | |
| logger.info("Generator closed by Gradio, cleaning up") | |
| _cleanup_done = True | |
| _training_lock.release() | |
| _gc.collect() | |
| _start_ace_server() | |
| shutil.rmtree(work_dir, ignore_errors=True) | |
| return | |
| except Exception as exc: | |
| _log(f"[FAIL] Training error: {exc}") | |
| import traceback | |
| _log(traceback.format_exc()) | |
| yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File() | |
| finally: | |
| if not _cleanup_done: | |
| _training_lock.release() | |
| _log("[INFO] Restarting ace-server...") | |
| yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File() | |
| _gc.collect() | |
| ok = _start_ace_server() | |
| if ok: | |
| _log("[OK] ace-server restarted successfully") | |
| else: | |
| _log("[WARN] ace-server may not have restarted -- check logs") | |
| adapter_safetensors = os.path.join(adapter_out, "adapter_model.safetensors") | |
| if os.path.isfile(adapter_safetensors): | |
| import zipfile | |
| tmp_zip = tempfile.NamedTemporaryFile( | |
| suffix=".zip", | |
| prefix=f"{lora_name}_", | |
| delete=False, | |
| ) | |
| tmp_zip.close() | |
| with zipfile.ZipFile(tmp_zip.name, "w", zipfile.ZIP_DEFLATED) as zf: | |
| zf.write(adapter_safetensors, f"{lora_name}/adapter_model.safetensors") | |
| adapter_config = os.path.join(adapter_out, "adapter_config.json") | |
| if os.path.isfile(adapter_config): | |
| zf.write(adapter_config, f"{lora_name}/adapter_config.json") | |
| # Include generated captions if they exist | |
| caption_count = 0 | |
| if os.path.isdir(audio_dir): | |
| for cf in sorted(os.listdir(audio_dir)): | |
| if cf.endswith(".json"): | |
| zf.write(os.path.join(audio_dir, cf), | |
| f"{lora_name}/captions/{cf}") | |
| caption_count += 1 | |
| _log(f"[OK] LoRA saved: {lora_name}" + | |
| (f" ({caption_count} captions included)" if caption_count else "")) | |
| yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File(value=tmp_zip.name, visible=True) | |
| else: | |
| yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File() | |
| shutil.rmtree(work_dir, ignore_errors=True) | |
# -- Cancel handler --
def _on_cancel():
    """Request cancellation of an in-flight training run.

    Sets the shared cancel flag checked by the training generator and
    returns a short status string for the training-log textbox.
    """
    cancel_training()
    logger.info("Cancel requested by user")
    status_message = "Cancelling..."
    return status_message
# -- Build LM model choices --
def _lm_model_choices():
    """Return the LM model dropdown choices (delegates to the model scanner)."""
    choices = _scan_lm_models()
    return choices
# -- Build UI --
CSS = """
.compact-row { gap: 8px !important; }
.status-box textarea { font-family: monospace; font-size: 13px; overflow-y: auto !important; }
"""
# FIX: custom CSS must be given to the gr.Blocks(css=...) constructor.
# Blocks.launch() has no `css` parameter, so passing CSS there meant the
# .compact-row / .status-box styles were never applied.
with gr.Blocks(title="ACE-Step 1.5 XL (CPU)", css=CSS) as demo:
    with gr.Tabs():
        # ============================================================
        # Tab 1: Generate Music
        # ============================================================
        with gr.Tab("Generate Music"):
            gr.Markdown("**[ACE-Step 1.5 XL](https://github.com/ace-step/ACE-Step-1.5)** GGUF Q4_K_M via [acestep.cpp](https://github.com/ServeurpersoCom/acestep.cpp) | ~5 min for 10s audio")
            with gr.Row(elem_classes="compact-row"):
                with gr.Column(scale=3):
                    # Left column: output player, status line, and prompts.
                    audio_out = gr.Audio(label="Output", type="filepath")
                    status = gr.Textbox(
                        label="Status",
                        interactive=False,
                        lines=1,
                        elem_classes="status-box",
                    )
                    caption = gr.Textbox(
                        label="Music Description",
                        lines=2,
                        value="upbeat electronic dance music, energetic synth leads",
                    )
                    lyrics = gr.Textbox(
                        label="Lyrics",
                        lines=3,
                        value="[Instrumental]",
                        placeholder="Enter lyrics here, or leave empty for instrumental (no vocals)",
                    )
                with gr.Column(scale=2):
                    # Right column: generate button and sampling parameters.
                    gen_btn = gr.Button("Generate Music", variant="primary")
                    with gr.Row(elem_classes="compact-row"):
                        bpm = gr.Number(label="BPM", value=120, minimum=0, maximum=300)
                        seed = gr.Number(label="Seed (-1=random)", value=-1)
                    with gr.Row(elem_classes="compact-row"):
                        duration = gr.Slider(
                            label="Duration (s)", minimum=10, maximum=120,
                            value=10, step=5,
                        )
                        # Step count is fixed for the turbo variant, hence
                        # interactive=False.
                        steps = gr.Slider(
                            label="Steps (8 for turbo)", minimum=1, maximum=32,
                            value=8, step=1, interactive=False,
                        )
                    with gr.Row(elem_classes="compact-row"):
                        # allow_custom_value lets the choices be refreshed
                        # after training without invalidating the selection.
                        lora_select = gr.Dropdown(
                            label="LoRA", choices=_list_lora_choices(),
                            value="None (no LoRA)",
                            allow_custom_value=True,
                        )
                        lm_model_select = gr.Dropdown(
                            label="LM Model", choices=_lm_model_choices(),
                            value=DEFAULT_LM,
                        )
            gen_btn.click(
                fn=generate_music,
                inputs=[caption, lyrics, bpm, duration,
                        seed, steps, lora_select, lm_model_select],
                outputs=[audio_out, status],
                api_name="generate",
            )
        # ============================================================
        # Tab 2: Train LoRA
        # ============================================================
        with gr.Tab("Train LoRA"):
            gr.Markdown("LoRA training ported from [Side-Step](https://github.com/koda-dernet/Side-Step) | Model: [ACE-Step 1.5](https://github.com/ace-step/ACE-Step-1.5) | ~8h for 3 files @ 200 epochs")
            with gr.Row(elem_classes="compact-row"):
                with gr.Column(scale=3):
                    # Left column: live log, result download, audio uploads.
                    train_log = gr.Textbox(
                        label="Training Log",
                        interactive=False,
                        lines=12,
                        max_lines=50,
                        autoscroll=True,
                        elem_classes="status-box",
                    )
                    train_output_file = gr.File(label="Trained LoRA (download)", visible=False)
                    train_audio = gr.File(
                        label="Training Audio — max 30 min total, ~2 min/epoch on CPU (optional caption .txt)",
                        file_count="multiple",
                        file_types=["audio", ".txt", ".json"],
                        height=120,
                    )
                with gr.Column(scale=2):
                    # Right column: controls and hyperparameters.
                    with gr.Row(elem_classes="compact-row"):
                        train_btn = gr.Button("Train", variant="primary", scale=2)
                        cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
                    lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
                    train_epochs = gr.Slider(
                        label="Epochs (200 recommended ~6h on CPU, best 500)",
                        minimum=1, maximum=1000,
                        value=200, step=1,
                    )
                    train_lr = gr.Number(label="Learning Rate", value=3e-4)
                    train_rank = gr.Slider(
                        label="Rank (r)", minimum=1, maximum=128,
                        value=16, step=1,
                    )
                    use_lm_caption = gr.Checkbox(
                        label="Use LM captioning (best quality, ~30 min/file)",
                        value=False,
                    )
            # Button swap on click (separate handler, like rvc-beatrice)
            # This fires immediately so user sees Cancel even if training
            # queues behind concurrency_limit=1
            train_btn.click(
                lambda: (gr.Button(visible=False), gr.Button(visible=True)),
                outputs=[train_btn, cancel_btn],
            )
            # Training generator -- yields (log, train_btn, cancel_btn, output_file)
            train_event = train_btn.click(
                train_lora_ui,
                inputs=[train_audio, lora_name, train_epochs, train_lr, train_rank, use_lm_caption],
                outputs=[train_log, train_btn, cancel_btn, train_output_file],
                api_name="train_lora",
                concurrency_limit=1,
            )

            # After training completes, restore buttons and refresh LoRA dropdown
            # This ensures cleanup even if the user navigated away
            def _post_training():
                """Restore Train/Cancel visibility and refresh the LoRA list."""
                return (
                    gr.Button(visible=True),
                    gr.Button(visible=False),
                    gr.Dropdown(choices=_list_lora_choices()),
                )

            train_event.then(
                _post_training,
                outputs=[train_btn, cancel_btn, lora_select],
            )
            # Cancel: set the flag, update status
            cancel_btn.click(
                _on_cancel,
                outputs=[train_log],
            )
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    mcp_server=True,
)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # start.sh invokes `python3 /app/app.py` with no extra argv, which starts
    # the Gradio UI; any additional argument selects CLI mode instead.
    if len(sys.argv) == 1:
        gradio_main()
    else:
        cli_main()