| | |
| |
|
| | import numpy as np |
| | from omegaconf import OmegaConf |
| |
|
| | from utils.io import write_json |
| |
|
| |
|
def compute_recall(errors):
    """Build a cumulative recall curve from a collection of error values.

    Args:
        errors: 1-D array-like of error values (list, tuple, ndarray, ...).
            The input is never modified.

    Returns:
        A tuple ``(errors, recall)`` of NumPy arrays of length
        ``len(errors) + 1``:

        * ``errors`` -- the inputs sorted in ascending order with a leading 0.
        * ``recall`` -- the fraction of samples whose error is less than or
          equal to the matching entry of ``errors``, with a leading 0.
    """
    # asarray accepts any array-like; the original ``errors.copy()`` call
    # required a ``.copy`` method and therefore rejected tuples.
    values = np.asarray(errors)
    num_elements = len(values)
    # np.sort returns a new array, so the caller's data is left untouched.
    sorted_errors = np.sort(values)
    recall = (np.arange(num_elements) + 1) / num_elements
    # Prepend the origin so the curve starts at (error=0, recall=0).
    return np.r_[0, sorted_errors], np.r_[0, recall]
| |
|
| |
|
def compute_auc(errors, recall, thresholds):
    """Compute the normalized area under a recall curve for each threshold.

    For every threshold ``t`` the curve is cut at ``t``, extended
    horizontally up to ``t`` with the last recall value reached, integrated
    with the trapezoidal rule, divided by ``t`` and scaled to a percentage.

    Args:
        errors: 1-D array-like of error values sorted in ascending order,
            e.g. the first output of ``compute_recall``. It should start at a
            value that is <= every threshold (``compute_recall`` prepends 0);
            otherwise ``last_index - 1`` becomes -1 and the last recall
            value is picked up instead.
        recall: 1-D array-like of recall values aligned with ``errors``.
        thresholds: iterable of non-zero thresholds (the area is divided by
            each threshold).

    Returns:
        A list with one AUC value (in percent) per threshold, in order.
    """
    # np.trapz was deprecated in NumPy 2.0 in favour of np.trapezoid; use
    # whichever the installed NumPy provides.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz

    # Convert once so plain lists work and slicing is cheap inside the loop.
    errors = np.asarray(errors)
    recall = np.asarray(recall)

    aucs = []
    for t in thresholds:
        # Number of curve points whose error is <= t.
        last_index = np.searchsorted(errors, t, side="right")
        # Close the curve at x = t with a flat segment at the last recall.
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        auc = integrate(r, x=e) / t
        aucs.append(auc * 100)
    return aucs
| |
|
| |
|
def write_dump(output_dir, experiment, cfg, results, metrics):
    """Serialize an experiment summary to ``output_dir / "log.json"``.

    The written dictionary holds the experiment identifier, the config
    converted to a plain container via ``OmegaConf.to_container``, the
    results, and an ``"errors"`` mapping with one entry per metric that
    exposes a ``get_errors`` attribute.

    Args:
        output_dir: directory object supporting the ``/`` operator
            (presumably a ``pathlib.Path`` -- confirm against callers).
        experiment: value stored under the ``"experiment"`` key.
        cfg: configuration passed to ``OmegaConf.to_container``.
        results: value stored under the ``"results"`` key.
        metrics: mapping from metric name to metric object.
    """
    payload = {
        "experiment": experiment,
        "cfg": OmegaConf.to_container(cfg),
        "results": results,
        # Only metrics that expose per-sample errors are dumped.
        "errors": {
            name: metric.get_errors().numpy()
            for name, metric in metrics.items()
            if hasattr(metric, "get_errors")
        },
    }
    write_json(output_dir / "log.json", payload)
| |
|