"""
Pipeline 包装脚本

此脚本作为独立子进程运行,执行 TrainingPipeline 并将进度以 JSON 格式输出到 stdout。
主进程(AsyncTrainingManager)通过解析 stdout 来获取实时进度。

进度消息格式:
    ##PROGRESS##{"type": "progress", "stage": "...", ...}##

Usage:
    python run_pipeline.py --config /path/to/config.json --task-id task-123
"""
| |
|
import argparse
import json
import os
import sys
import traceback
from datetime import datetime, timezone
from typing import Any, Dict
| |
|
| | |
| | from pathlib import Path |
| | _SCRIPT_DIR = Path(__file__).parent.resolve() |
| | _API_SERVER_ROOT = _SCRIPT_DIR.parent.parent |
| | _PROJECT_ROOT = _API_SERVER_ROOT.parent |
| | sys.path.insert(0, str(_PROJECT_ROOT)) |
| |
|
| | |
| | from project_config import settings, PROJECT_ROOT, get_pythonpath |
| |
|
| |
|
| | |
| | PROGRESS_PREFIX = "##PROGRESS##" |
| | PROGRESS_SUFFIX = "##" |
| |
|
| |
|
| | def emit_progress(progress_info: Dict[str, Any]) -> None: |
| | """ |
| | 输出进度消息到 stdout |
| | |
| | Args: |
| | progress_info: 进度信息字典 |
| | """ |
| | |
| | if "timestamp" not in progress_info: |
| | progress_info["timestamp"] = datetime.utcnow().isoformat() |
| | |
| | json_str = json.dumps(progress_info, ensure_ascii=False) |
| | print(f"{PROGRESS_PREFIX}{json_str}{PROGRESS_SUFFIX}", flush=True) |
| |
|
| |
|
| | def emit_log(level: str, message: str, **extra) -> None: |
| | """ |
| | 输出日志消息 |
| | |
| | Args: |
| | level: 日志级别 (info, warning, error) |
| | message: 日志消息 |
| | **extra: 额外数据 |
| | """ |
| | emit_progress({ |
| | "type": "log", |
| | "level": level, |
| | "message": message, |
| | **extra |
| | }) |
| |
|
| |
|
| | def load_config(config_path: str) -> Dict[str, Any]: |
| | """ |
| | 加载配置文件 |
| | |
| | Args: |
| | config_path: 配置文件路径 |
| | |
| | Returns: |
| | 配置字典 |
| | """ |
| | with open(config_path, 'r', encoding='utf-8') as f: |
| | return json.load(f) |
| |
|
| |
|
| | def build_pipeline(config: Dict[str, Any]): |
| | """ |
| | 根据配置构建 TrainingPipeline |
| | |
| | Args: |
| | config: 配置字典,包含: |
| | - exp_name: 实验名称 |
| | - version: 模型版本 |
| | - stages: 要执行的阶段列表 |
| | - 各阶段的具体配置 |
| | |
| | Returns: |
| | TrainingPipeline 实例 |
| | """ |
| | from training_pipeline import ( |
| | TrainingPipeline, |
| | ModelVersion, |
| | |
| | AudioSliceConfig, |
| | ASRConfig, |
| | DenoiseConfig, |
| | FeatureExtractionConfig, |
| | SoVITSTrainConfig, |
| | GPTTrainConfig, |
| | InferenceConfig, |
| | |
| | AudioSliceStage, |
| | ASRStage, |
| | DenoiseStage, |
| | TextFeatureStage, |
| | HuBERTFeatureStage, |
| | SemanticTokenStage, |
| | SoVITSTrainStage, |
| | GPTTrainStage, |
| | InferenceStage, |
| | ) |
| | |
| | pipeline = TrainingPipeline() |
| | |
| | exp_name = config["exp_name"] |
| | version_str = config.get("version", "v2") |
| | version = ModelVersion(version_str) if isinstance(version_str, str) else version_str |
| | |
| | |
| | base_params = { |
| | "exp_name": exp_name, |
| | "exp_root": config.get("exp_root", "logs"), |
| | "gpu_numbers": config.get("gpu_numbers", "0"), |
| | "is_half": config.get("is_half", True), |
| | } |
| | |
| | |
| | stage_builders = { |
| | "audio_slice": lambda cfg: AudioSliceStage(AudioSliceConfig( |
| | **base_params, |
| | input_path=cfg.get("input_path", ""), |
| | threshold=cfg.get("threshold", -34), |
| | min_length=cfg.get("min_length", 4000), |
| | min_interval=cfg.get("min_interval", 300), |
| | hop_size=cfg.get("hop_size", 10), |
| | max_sil_kept=cfg.get("max_sil_kept", 500), |
| | max_amp=cfg.get("max_amp", 0.9), |
| | alpha=cfg.get("alpha", 0.25), |
| | n_parts=cfg.get("n_parts", 4), |
| | )), |
| | |
| | "asr": lambda cfg: ASRStage(ASRConfig( |
| | **base_params, |
| | model=cfg.get("model", "达摩 ASR (中文)"), |
| | model_size=cfg.get("model_size", "large"), |
| | language=cfg.get("language", "zh"), |
| | precision=cfg.get("precision", "float32"), |
| | )), |
| | |
| | "denoise": lambda cfg: DenoiseStage(DenoiseConfig( |
| | **base_params, |
| | input_dir=cfg.get("input_dir", ""), |
| | output_dir=cfg.get("output_dir", "output/denoise_opt"), |
| | )), |
| | |
| | "text_feature": lambda cfg: TextFeatureStage(FeatureExtractionConfig( |
| | **base_params, |
| | version=version, |
| | bert_pretrained_dir=cfg.get("bert_pretrained_dir", |
| | "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"), |
| | ssl_pretrained_dir=cfg.get("ssl_pretrained_dir", |
| | "GPT_SoVITS/pretrained_models/chinese-hubert-base"), |
| | pretrained_s2G=cfg.get("pretrained_s2G", |
| | "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"), |
| | )), |
| | |
| | "hubert_feature": lambda cfg: HuBERTFeatureStage(FeatureExtractionConfig( |
| | **base_params, |
| | version=version, |
| | bert_pretrained_dir=cfg.get("bert_pretrained_dir", |
| | "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"), |
| | ssl_pretrained_dir=cfg.get("ssl_pretrained_dir", |
| | "GPT_SoVITS/pretrained_models/chinese-hubert-base"), |
| | pretrained_s2G=cfg.get("pretrained_s2G", |
| | "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"), |
| | )), |
| | |
| | "semantic_token": lambda cfg: SemanticTokenStage(FeatureExtractionConfig( |
| | **base_params, |
| | version=version, |
| | bert_pretrained_dir=cfg.get("bert_pretrained_dir", |
| | "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"), |
| | ssl_pretrained_dir=cfg.get("ssl_pretrained_dir", |
| | "GPT_SoVITS/pretrained_models/chinese-hubert-base"), |
| | pretrained_s2G=cfg.get("pretrained_s2G", |
| | "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"), |
| | )), |
| | |
| | "sovits_train": lambda cfg: SoVITSTrainStage(SoVITSTrainConfig( |
| | **base_params, |
| | version=version, |
| | batch_size=cfg.get("batch_size", 4), |
| | total_epoch=cfg.get("total_epoch", 8), |
| | text_low_lr_rate=cfg.get("text_low_lr_rate", 0.4), |
| | save_every_epoch=cfg.get("save_every_epoch", 4), |
| | if_save_latest=cfg.get("if_save_latest", True), |
| | if_save_every_weights=cfg.get("if_save_every_weights", True), |
| | pretrained_s2G=cfg.get("pretrained_s2G", |
| | "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"), |
| | pretrained_s2D=cfg.get("pretrained_s2D", |
| | "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2D2333k.pth"), |
| | if_grad_ckpt=cfg.get("if_grad_ckpt", False), |
| | lora_rank=cfg.get("lora_rank", 32), |
| | )), |
| | |
| | "gpt_train": lambda cfg: GPTTrainStage(GPTTrainConfig( |
| | **base_params, |
| | version=version, |
| | batch_size=cfg.get("batch_size", 4), |
| | total_epoch=cfg.get("total_epoch", 15), |
| | save_every_epoch=cfg.get("save_every_epoch", 5), |
| | if_save_latest=cfg.get("if_save_latest", True), |
| | if_save_every_weights=cfg.get("if_save_every_weights", True), |
| | if_dpo=cfg.get("if_dpo", False), |
| | pretrained_s1=cfg.get("pretrained_s1", |
| | "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"), |
| | )), |
| | |
| | "inference": lambda cfg: InferenceStage(InferenceConfig( |
| | **base_params, |
| | version=version, |
| | gpt_path=cfg.get("gpt_path", ""), |
| | sovits_path=cfg.get("sovits_path", ""), |
| | ref_text=cfg.get("ref_text", ""), |
| | ref_audio_path=cfg.get("ref_audio_path", ""), |
| | target_text=cfg.get("target_text", ""), |
| | text_split_method=cfg.get("text_split_method", "cut1"), |
| | )), |
| | } |
| | |
| | |
| | |
| | |
| | |
| | stages = config.get("stages", []) |
| | for stage_item in stages: |
| | |
| | if isinstance(stage_item, str): |
| | stage_type = stage_item |
| | stage_config = config |
| | elif isinstance(stage_item, dict): |
| | stage_type = stage_item.get("type") |
| | |
| | stage_config = {**config, **stage_item} |
| | else: |
| | emit_log("warning", f"无效的阶段配置类型: {type(stage_item)}") |
| | continue |
| | |
| | if stage_type in stage_builders: |
| | stage = stage_builders[stage_type](stage_config) |
| | pipeline.add_stage(stage) |
| | emit_log("info", f"已添加阶段: {stage.name}") |
| | else: |
| | emit_log("warning", f"未知阶段类型: {stage_type}") |
| | |
| | return pipeline |
| |
|
| |
|
| | def run_pipeline(config: Dict[str, Any], task_id: str) -> bool: |
| | """ |
| | 执行 Pipeline |
| | |
| | Args: |
| | config: 配置字典 |
| | task_id: 任务ID |
| | |
| | Returns: |
| | 是否成功完成 |
| | """ |
| | emit_progress({ |
| | "type": "progress", |
| | "status": "running", |
| | "message": "正在初始化训练流水线...", |
| | "task_id": task_id, |
| | "progress": 0.0, |
| | "overall_progress": 0.0, |
| | }) |
| | |
| | try: |
| | pipeline = build_pipeline(config) |
| | |
| | stages = pipeline.get_stages() |
| | if not stages: |
| | emit_progress({ |
| | "type": "progress", |
| | "status": "failed", |
| | "message": "没有配置任何训练阶段", |
| | "task_id": task_id, |
| | }) |
| | return False |
| | |
| | emit_log("info", f"训练流水线已初始化,共 {len(stages)} 个阶段") |
| | |
| | |
| | for progress in pipeline.run(): |
| | |
| | emit_progress({ |
| | "type": "progress", |
| | "status": "running", |
| | "stage": progress.get("stage"), |
| | "stage_index": progress.get("stage_index"), |
| | "total_stages": progress.get("total_stages"), |
| | "progress": progress.get("progress", 0.0), |
| | "overall_progress": progress.get("overall_progress", 0.0), |
| | "message": progress.get("message"), |
| | "task_id": task_id, |
| | "data": progress.get("data", {}), |
| | }) |
| | |
| | |
| | if progress.get("status") == "failed": |
| | emit_progress({ |
| | "type": "progress", |
| | "status": "failed", |
| | "stage": progress.get("stage"), |
| | "message": progress.get("message", "阶段执行失败"), |
| | "task_id": task_id, |
| | }) |
| | return False |
| | |
| | |
| | emit_progress({ |
| | "type": "progress", |
| | "status": "completed", |
| | "message": "训练流水线执行完成", |
| | "task_id": task_id, |
| | "progress": 1.0, |
| | "overall_progress": 1.0, |
| | }) |
| | return True |
| | |
| | except Exception as e: |
| | error_msg = str(e) |
| | error_trace = traceback.format_exc() |
| | emit_progress({ |
| | "type": "progress", |
| | "status": "failed", |
| | "message": f"执行出错: {error_msg}", |
| | "error": error_msg, |
| | "traceback": error_trace, |
| | "task_id": task_id, |
| | }) |
| | return False |
| |
|
| |
|
| | def main(): |
| | """主函数""" |
| | parser = argparse.ArgumentParser(description="执行 GPT-SoVITS 训练流水线") |
| | parser.add_argument("--config", required=True, help="配置文件路径 (JSON)") |
| | parser.add_argument("--task-id", required=True, help="任务ID") |
| | |
| | args = parser.parse_args() |
| | |
| | emit_log("info", f"启动训练任务: {args.task_id}") |
| | emit_log("info", f"配置文件: {args.config}") |
| | |
| | try: |
| | config = load_config(args.config) |
| | except Exception as e: |
| | emit_progress({ |
| | "type": "progress", |
| | "status": "failed", |
| | "message": f"加载配置文件失败: {e}", |
| | "task_id": args.task_id, |
| | }) |
| | sys.exit(1) |
| | |
| | success = run_pipeline(config, args.task_id) |
| | sys.exit(0 if success else 1) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|