# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import ast
from itertools import chain
import logging
import math
import os
import sys
import json
import hashlib
import editdistance
from argparse import Namespace

import numpy as np
import torch
from fairseq import checkpoint_utils, options, tasks, utils, distributed_utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.models import FairseqLanguageModel
from omegaconf import DictConfig, OmegaConf

from pathlib import Path
import hydra
from hydra.core.config_store import ConfigStore
from fairseq.dataclass.configs import (
    CheckpointConfig,
    CommonConfig,
    CommonEvalConfig,
    DatasetConfig,
    DistributedTrainingConfig,
    GenerationConfig,
    FairseqDataclass,
)
from dataclasses import dataclass, field, is_dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

logging.root.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

config_path = Path(__file__).resolve().parent / "conf"
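
# Hydra config layout: InferConfig (below) bundles the standard fairseq config
# groups, while OverrideConfig collects the few task fields that may be changed
# at inference time (noise mixing, modalities, data/label paths) without editing
# the task config stored inside the checkpoint.
# Illustrative invocation (paths and values are placeholders, not from this repo):
#   python infer.py common_eval.path=/path/to/checkpoint.pt \
#       override.modalities='[video]' override.noise_prob=0.25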


@dataclass
class OverrideConfig(FairseqDataclass):
    noise_wav: Optional[str] = field(default=None, metadata={'help': 'noise wav file'})
    noise_prob: float = field(default=0, metadata={'help': 'noise probability'})
    noise_snr: float = field(default=0, metadata={'help': 'noise SNR in audio'})
    modalities: List[str] = field(default_factory=lambda: [""], metadata={'help': 'which modality to use'})
    data: Optional[str] = field(default=None, metadata={'help': 'path to test data directory'})
    label_dir: Optional[str] = field(default=None, metadata={'help': 'path to test label directory'})


@dataclass
class InferConfig(FairseqDataclass):
    task: Any = None
    generation: GenerationConfig = GenerationConfig()
    common: CommonConfig = CommonConfig()
    common_eval: CommonEvalConfig = CommonEvalConfig()
    checkpoint: CheckpointConfig = CheckpointConfig()
    distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
    dataset: DatasetConfig = DatasetConfig()
    override: OverrideConfig = OverrideConfig()
    is_ax: bool = field(
        default=False,
        metadata={
            "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
        },
    )
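

# main() validates the eval options, then runs _main(), teeing the decode log
# to results_path/decode.log when a results directory is given.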
def main(cfg: DictConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    assert cfg.common_eval.path is not None, "--path required for recognition!"
    assert (
        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
    ), "--sampling requires --nbest to be equal to --beam"

    if cfg.common_eval.results_path is not None:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(cfg.common_eval.results_path, "decode.log")
        with open(output_path, "w", buffering=1, encoding="utf-8") as h:
            return _main(cfg, h)
    return _main(cfg, sys.stdout)


def get_symbols_to_strip_from_output(generator):
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    else:
        return {generator.eos, generator.pad}
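

# _main() does the actual work: load the model ensemble and its saved task
# config, rebuild the task with the CLI overrides applied, batch the evaluation
# split, run beam search, and write hypotheses plus a WER summary to disk.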
def _main(cfg, output_file):
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("hybrid.speech_recognize")
    if output_file is not sys.stdout:  # also print to stdout
        logger.addHandler(logging.StreamHandler(sys.stdout))

    utils.import_user_module(cfg.common)
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([cfg.common_eval.path])
    # Switch to eval mode here; models are moved to GPU further below, and only
    # when CUDA is actually available.
    models = [model.eval() for model in models]
    saved_cfg.task.modalities = cfg.override.modalities
    task = tasks.setup_task(saved_cfg.task)
    task.build_tokenizer(saved_cfg.tokenizer)
    task.build_bpe(saved_cfg.bpe)

    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available()

    # Set dictionary
    dictionary = task.target_dictionary

    # Loading the dataset should happen after the checkpoint has been loaded,
    # so we can give it the saved task config.
    task.cfg.noise_prob = cfg.override.noise_prob
    task.cfg.noise_snr = cfg.override.noise_snr
    task.cfg.noise_wav = cfg.override.noise_wav
    if cfg.override.data is not None:
        task.cfg.data = cfg.override.data
    if cfg.override.label_dir is not None:
        task.cfg.label_dir = cfg.override.label_dir
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
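
    # No external language model is fused here; the single None entry keeps the
    # model-preparation loop below uniform for models and LMs alike.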
    lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    if cfg.generation.match_source_len:
        logger.warning(
            "The option match_source_len is not applicable to speech recognition. Ignoring it."
        )
    gen_timer = StopwatchMeter()
    extra_gen_cls_kwargs = {
        "lm_model": lms[0],
        "lm_weight": cfg.generation.lm_weight,
    }
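    # extra_gen_cls_kwargs wires the (here absent) external LM and its
    # shallow-fusion weight into the sequence generator built by the task.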
    cfg.generation.score_reference = False  # always decode; never just score the reference
    save_attention_plot = cfg.generation.print_alignment is not None
    cfg.generation.print_alignment = None  # alignment output is not produced during decoding here
    generator = task.build_generator(
        models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
    )
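
    # decode_fn maps token IDs back to text: prefer the dataset's own label
    # processor when it provides a decode() method, otherwise fall back to the
    # dictionary and treat '|' as the word boundary.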
    def decode_fn(x):
        symbols_ignore = get_symbols_to_strip_from_output(generator)
        symbols_ignore.add(dictionary.pad())
        if hasattr(task.datasets[cfg.dataset.gen_subset].label_processors[0], 'decode'):
            return task.datasets[cfg.dataset.gen_subset].label_processors[0].decode(x, symbols_ignore)
        chars = dictionary.string(x, extra_symbols_to_ignore=symbols_ignore)
        words = " ".join("".join(chars.split()).replace('|', ' ').split())
        return words
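
    # Decode loop: for every batch, run beam search, keep the best hypothesis
    # per utterance, and accumulate (utt_id, reference, hypothesis) for scoring.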
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    result_dict = {'utt_id': [], 'ref': [], 'hypo': []}
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i in range(len(sample["id"])):
            result_dict['utt_id'].append(sample['utt_id'][i])
            ref_sent = decode_fn(sample['target'][i].int().cpu())
            result_dict['ref'].append(ref_sent)
            best_hypo = hypos[i][0]['tokens'].int().cpu()
            hypo_str = decode_fn(best_hypo)
            result_dict['hypo'].append(hypo_str)
            logger.info(f"\nREF:{ref_sent}\nHYP:{hypo_str}\n")

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
| logger.info("NOTE: hypothesis and token scores are output in base 2") | |
| logger.info("Recognized {:,} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( | |
| num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) | |
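
    # Name the result files after a hash of the generation config, so runs with
    # different decoding settings do not overwrite each other.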
    yaml_str = OmegaConf.to_yaml(cfg.generation)
    fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16) % 1000000
    result_fn = f"{cfg.common_eval.results_path}/hypo-{fid}.json"
    with open(result_fn, 'w') as fo:
        json.dump(result_dict, fo, indent=4)
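
    # Corpus WER = 100 * (substitutions + deletions + insertions) / reference
    # words, computed via word-level edit distance over whitespace tokens.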
    n_err, n_total = 0, 0
    assert len(result_dict['hypo']) == len(result_dict['ref'])
    for hypo, ref in zip(result_dict['hypo'], result_dict['ref']):
        hypo, ref = hypo.strip().split(), ref.strip().split()
        n_err += editdistance.eval(hypo, ref)
        n_total += len(ref)
    wer = 100 * n_err / n_total
    wer_fn = f"{cfg.common_eval.results_path}/wer.{fid}"
    with open(wer_fn, "w") as fo:
        fo.write(f"WER: {wer}\n")
        fo.write(f"err / num_ref_words = {n_err} / {n_total}\n\n")
        fo.write(f"{yaml_str}")
    logger.info(f"WER: {wer}%")
    return


@hydra.main(config_path=config_path, config_name="infer")
def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]:
    container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
    cfg = OmegaConf.create(container)
    OmegaConf.set_struct(cfg, True)

    if cfg.common.reset_logging:
        utils.reset_logging()  # Hydra hijacks the root logging config; undo that
    wer = float("inf")

    try:
        if cfg.common.profile:
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, main)
        else:
            distributed_utils.call_main(cfg, main)
    except BaseException as e:  # pylint: disable=broad-except
        if not cfg.common.suppress_crashes:
            raise
        else:
            logger.error("Crashed! %s", str(e))
    return
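

# cli_main() registers InferConfig (and each nested dataclass config group)
# with Hydra's ConfigStore under the requested config name, then hands control
# to hydra_main.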
def cli_main() -> None:
    try:
        from hydra._internal.utils import (
            get_args,
        )  # pylint: disable=import-outside-toplevel

        cfg_name = get_args().config_name or "infer"
    except ImportError:
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "infer"

    cs = ConfigStore.instance()
    cs.store(name=cfg_name, node=InferConfig)
    for k in InferConfig.__dataclass_fields__:
        if is_dataclass(InferConfig.__dataclass_fields__[k].type):
            v = InferConfig.__dataclass_fields__[k].default
            cs.store(name=k, node=v)
    hydra_main()  # pylint: disable=no-value-for-parameter


if __name__ == "__main__":
    cli_main()