| """ |
| Module: preprocess.py |
| |
| This module provides a preprocessing pipeline for single-cell RNA sequencing (scRNA-seq) data |
| stored in AnnData format. It includes functions for loading data, filtering cells and genes, |
| normalizing and scaling data, and saving processed results. The pipeline is designed to be |
| configurable via hyperparameters and supports various preprocessing steps such as mitochondrial |
| gene filtering, highly variable gene selection, and log transformation. |
| |
| Main Features: |
| - Load and preprocess scRNA-seq data in AnnData format. |
| - Filter cells and genes based on various criteria. |
| - Normalize, scale, and log-transform data. |
| - Save processed data and metadata to disk. |
| - Configurable via JSON-based hyperparameters. |
| |
| Dependencies: |
| - anndata, numpy, pandas, scanpy, scipy, sklearn |
| |
| Usage: |
| - Run this script as a standalone program with a configuration file specifying the hyperparameters. |
| - Import the `preprocess` function and call it with the data path, metadata path, and hyperparameters. |
| """ |
|
|
import gc
import json
import logging
import os
import warnings
from argparse import ArgumentParser

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
from anndata import ImplicitModificationWarning
from scipy.sparse import csr_matrix, issparse
from sklearn.utils import sparsefuncs, sparsefuncs_fast

from teddy.data_processing.utils.gene_mapping.gene_mapper import (
    map_mouse_human,
    map_mouse_human2,
)

_logger = logging.getLogger(__name__)
|
|
_HUMAN_MITO_ENSEMBL = {
    # Mitochondrial rRNAs
    "ENSG00000211459", "ENSG00000210082",
    # Mitochondrial tRNAs
    "ENSG00000210049", "ENSG00000210077", "ENSG00000209082",
    "ENSG00000210100", "ENSG00000210107", "ENSG00000210112",
    "ENSG00000210119", "ENSG00000210122", "ENSG00000210116",
    "ENSG00000210117", "ENSG00000210118", "ENSG00000210124",
    "ENSG00000210126", "ENSG00000210134", "ENSG00000210135",
    "ENSG00000210142", "ENSG00000210144", "ENSG00000210148",
    "ENSG00000210150", "ENSG00000210155", "ENSG00000210196",
    "ENSG00000210151",
    # Protein-coding mitochondrial genes
    "ENSG00000198888", "ENSG00000198763", "ENSG00000198840",
    "ENSG00000198886", "ENSG00000212907", "ENSG00000198786",
    "ENSG00000198695", "ENSG00000198804", "ENSG00000198712",
    "ENSG00000198938", "ENSG00000198899", "ENSG00000228253",
    "ENSG00000198727",
}

_HUMAN_MITO_SYMBOLS = {
    "MT-RNR1", "MT-RNR2", "MT-TF", "MT-TV", "MT-TL1", "MT-TI", "MT-TQ",
    "MT-TM", "MT-TW", "MT-TA", "MT-TN", "MT-TC", "MT-TY", "MT-TD", "MT-TK",
    "MT-TG", "MT-TR", "MT-TH", "MT-TS2", "MT-TL2", "MT-TT", "MT-TE", "MT-TP",
    "MT-TS1", "MT-ND1", "MT-ND2", "MT-ND3", "MT-ND4", "MT-ND4L", "MT-ND5",
    "MT-ND6", "MT-CO1", "MT-CO2", "MT-CO3", "MT-ATP6", "MT-ATP8", "MT-CYB",
}
|
|
|
|
def load_data_and_metadata(data_path: str, metadata_path: str):
    """
    Load an AnnData .h5ad file (data) and a JSON file (metadata).
    """
    data = ad.read_h5ad(data_path)
    with open(metadata_path, "r") as f:
        metadata = json.load(f)
    return data, metadata
|
|
|
|
def set_raw_if_necessary(data: ad.AnnData):
    """
    If data.raw is None, check whether the first ~64 cells of
    data.layers['counts'] (or, failing that, data.X) are integer-valued.
    If so, set data.raw from that matrix. Otherwise return None (skip).
    """
    if data.raw is not None:
        return data

    if "counts" in data.layers:
        X = data.layers["counts"]
        if issparse(X):
            X_sample = X[:64].toarray()
        else:
            X_sample = np.asarray(X[:64])
        if np.all(np.equal(np.mod(X_sample, 1), 0)):
            data.raw = ad.AnnData(X=data.layers["counts"], var=data.var.copy())
            return data

    X = data.X
    if issparse(X):
        X_sample = X[:64].toarray()
    else:
        X_sample = np.asarray(X[:64])
    if np.all(np.equal(np.mod(X_sample, 1), 0)):
        data.raw = data
        return data
    else:
        print("No integer-valued matrix found")
        return None
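
# Minimal usage sketch (hypothetical toy data, not part of the pipeline): an
# integer-valued X should be promoted to .raw.
def _demo_set_raw_if_necessary():
    rng = np.random.default_rng(0)
    toy = ad.AnnData(X=csr_matrix(rng.poisson(1.0, size=(8, 4)).astype(np.float32)))
    toy = set_raw_if_necessary(toy)
    assert toy is not None and toy.raw is not None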
|
|
|
|
|
|
|
|
def initialize_processed_layer(data: ad.AnnData):
    """
    If the 'processed' layer is missing, copy it from data.raw.X.
    """
    if "processed" not in data.layers:
        data.layers["processed"] = data.raw.X.astype("float32")
    return data


def filter_reference_id(data: ad.AnnData, hyperparameters: dict):
    """
    Map gene identifiers to a common (human) reference_id, drop genes without
    a mapping, and collapse duplicated reference_ids by taking the
    element-wise maximum across the duplicated columns.
    """
    human_map = pd.read_csv("teddy/data_processing/utils/gene_mapping/data/human_mapping.txt", sep="\t")
    mouse_map = pd.read_csv("teddy/data_processing/utils/gene_mapping/data/2407_mouse_gene_mapping.txt", sep="\t")
    orthologs = pd.read_csv(
        "teddy/data_processing/utils/gene_mapping/data/mouse_to_human_orthologs.one2one.txt", sep="\t"
    )

    if hyperparameters.get("mouse_nonorthologs", False):
        reference_id = map_mouse_human2(
            data_frame=data.var,
            query_column=None,
            human_map_db=human_map,
            mouse_map_db=mouse_map,
            orthology_db=orthologs,
        )["reference_id"]
    else:
        reference_id = map_mouse_human(
            data_frame=data.var,
            query_column=None,
            human_map_db=human_map,
            mouse_map_db=mouse_map,
            orthology_db=orthologs,
        )["reference_id"]

    # Keep only genes that received a non-empty reference_id.
    valid_mask = reference_id != ""
    data = data[:, valid_mask].copy()
    reference_id = reference_id[valid_mask].reset_index(drop=True)

    if not isinstance(data.layers["processed"], np.ndarray):
        corrected = data.layers["processed"].toarray()
    else:
        corrected = data.layers["processed"]

    # Collapse duplicated reference_ids: keep the first occurrence and store
    # the element-wise max over all duplicated columns in it.
    unique_ids = reference_id.unique()
    vars_to_keep = []
    for rid in unique_ids:
        repeated_idx = np.where(reference_id == rid)[0]
        vars_to_keep.append(repeated_idx[0])
        if len(repeated_idx) > 1:
            corrected[:, repeated_idx[0]] = corrected[:, repeated_idx].max(axis=1)

    vars_to_keep = sorted(vars_to_keep)
    corrected = corrected[:, vars_to_keep]
    data = data[:, vars_to_keep]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ImplicitModificationWarning)
        data.layers["processed"] = csr_matrix(corrected)
        data.var["reference_id"] = list(reference_id[vars_to_keep])

    gc.collect()
    return data
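
# Worked example of the duplicate-collapse rule above (hypothetical values):
# if two processed columns map to the same reference_id, e.g.
#     [[1, 5],
#      [3, 2]]
# they collapse into the single column [[5], [3]], the element-wise max.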
|
|
|
|
|
|
|
|
def remove_assays(data: ad.AnnData, assays_to_remove: list):
    """
    Remove observations belonging to the specified 'assay' categories
    (requires 'assay' to be present in data.obs).
    """
    data = data[~data.obs.assay.isin(assays_to_remove)].copy()
    gc.collect()
    return data
|
|
|
|
def filter_cells_by_gene_counts(data: ad.AnnData, min_count: int):
    """
    Remove cells (observations) whose total gene counts are < min_count.
    """
    mask = sc.pp.filter_cells(data.layers["processed"], min_counts=min_count)[0]
    data = data[mask].copy()
    del mask
    gc.collect()
    return data
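
# Minimal usage sketch (toy values chosen for illustration): the second cell
# has a total count of 1 and is removed at min_count=5.
def _demo_filter_cells_by_gene_counts():
    X = csr_matrix(np.array([[5.0, 5.0], [0.0, 1.0]]))
    toy = ad.AnnData(X=X.copy())
    toy.layers["processed"] = X
    kept = filter_cells_by_gene_counts(toy, min_count=5)
    assert kept.n_obs == 1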
|
|
|
|
def filter_cells_by_mitochondrial_fraction(data: ad.AnnData, max_mito_prop: float):
    """
    Remove low-quality cells whose mitochondrial read fraction exceeds
    *max_mito_prop*.

    DO NOT RUN THIS IN ANY PREPROCESSING PIPELINE UNTIL YOU HAVE SET RAW COUNTS.

    Parameters
    ----------
    data
        `AnnData` object containing counts. Works with dense or sparse matrices.
    max_mito_prop
        Threshold above which cells are discarded.

    Returns
    -------
    AnnData
        A **copy** of `data` with poor-quality cells removed and two new
        columns added to ``.obs``:
        - **mito_prop** – per-cell mitochondrial fraction
        - **poor_quality_mito** – boolean flag marking dropped cells
    """
    counts = data.X
    var_index = data.var_names
    # Pick the reference set matching the identifier scheme of var_names.
    if var_index[0].startswith("ENSG"):
        ref = _HUMAN_MITO_ENSEMBL
    else:
        ref = _HUMAN_MITO_SYMBOLS
    mito_idx = np.flatnonzero(var_index.isin(ref))
    if mito_idx.size == 0:
        _logger.info("No mitochondrial genes found, returning data")
        return data
    if sp.issparse(counts):
        total = counts.sum(axis=1).A1
        mito = counts[:, mito_idx].sum(axis=1).A1
    else:
        total = counts.sum(axis=1)
        mito = counts[:, mito_idx].sum(axis=1)
    mito_prop = mito / np.maximum(total, 1)
    data.obs["mito_prop"] = mito_prop
    data.obs["poor_quality_mito"] = mito_prop > max_mito_prop
    filtered = data[~data.obs["poor_quality_mito"]].copy()
    gc.collect()
    return filtered
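
# Minimal usage sketch (hypothetical two-gene panel): the second cell carries
# 90% mitochondrial reads and is dropped at max_mito_prop=0.5.
def _demo_filter_cells_by_mitochondrial_fraction():
    X = csr_matrix(np.array([[90.0, 10.0], [10.0, 90.0]]))
    toy = ad.AnnData(X=X, var=pd.DataFrame(index=["ACTB", "MT-ND1"]))
    kept = filter_cells_by_mitochondrial_fraction(toy, max_mito_prop=0.5)
    assert kept.n_obs == 1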
|
|
|
|
def filter_highly_variable_genes(data: ad.AnnData, method: str):
    """
    Filter genes to those that are highly variable using scanpy.
    method must be "seurat_v3" or "cell_ranger".
    """
    if "highly_variable" not in data.var:
        sc.pp.highly_variable_genes(data, flavor=method, n_top_genes=10000)
    # Subset to the flagged genes.
    data = data[:, data.var["highly_variable"]]
    gc.collect()
    return data
|
|
|
|
def normalize_data_inplace(matrix_csr: csr_matrix, norm_value: float):
    """
    In-place row normalization + scale: each row is L1-normalized and then
    multiplied by norm_value, so every cell's counts sum to norm_value.
    matrix_csr must be a CSR matrix.
    """
    sparsefuncs_fast.inplace_csr_row_normalize_l1(matrix_csr)
    scale_factors = np.full(matrix_csr.shape[0], norm_value)
    sparsefuncs.inplace_row_scale(matrix_csr, scale_factors)
    gc.collect()
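
# Minimal usage sketch: after normalization every row of the toy matrix sums
# to the target value (here 10000).
def _demo_normalize_data_inplace():
    mat = csr_matrix(np.array([[1.0, 3.0], [2.0, 2.0]]))
    normalize_data_inplace(mat, 10000.0)
    assert np.allclose(mat.sum(axis=1).A1, [10000.0, 10000.0])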
|
|
|
|
def scale_columns_by_median_dict(layer: csr_matrix, data: ad.AnnData, median_dict_path: str, median_column: str):
    """
    Read a JSON median_dict and scale each column by 1/median. The lookup key
    is either data.var.index or data.var[median_column]; genes missing from
    the dict are left unscaled (factor 1.0).
    """
    with open(median_dict_path) as f:
        median_dict = json.load(f)

    if median_column == "index":
        median_var = data.var.index
    else:
        median_var = data.var[median_column]

    factors = np.array(
        [1.0 / median_dict[g] if g in median_dict else 1.0 for g in median_var]
    )

    sparsefuncs.inplace_csr_column_scale(layer, factors)
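
# Minimal usage sketch (hypothetical medians written to a temp file): the
# first gene is divided by its median of 2.0, the unknown gene is untouched.
def _demo_scale_columns_by_median_dict():
    import tempfile

    toy = ad.AnnData(X=csr_matrix(np.array([[4.0, 4.0]])), var=pd.DataFrame(index=["G1", "G2"]))
    toy.layers["processed"] = toy.X.copy()
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump({"G1": 2.0}, f)
    scale_columns_by_median_dict(toy.layers["processed"], toy, f.name, "index")
    os.remove(f.name)
    assert np.allclose(toy.layers["processed"].toarray(), [[2.0, 4.0]])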
|
|
|
|
def log_transform_layer(data: ad.AnnData, layer_name: str = "processed"):
    """
    Apply sc.pp.log1p in place to data.layers[layer_name].
    """
    sc.pp.log1p(data, layer=layer_name, copy=False)
|
|
|
|
def compute_and_save_medians(data: ad.AnnData, data_path: str, hyperparameters: dict):
    """
    Convert zeros to NaN, compute column medians ignoring NaN, and save the
    results as JSON next to the processed file.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")

        # Zeros are treated as "not expressed" and excluded from the median.
        mat = data.layers["processed"].toarray()
        mat[mat == 0] = np.nan
        medians = np.nanmedian(mat, axis=0)

        if hyperparameters["median_column"] == "index":
            median_var = data.var.index.copy()
            if not isinstance(median_var, pd.Series):
                median_var = pd.Series(median_var)
        else:
            median_var = data.var[hyperparameters["median_column"]].copy()

        valid_idxs = np.where(~np.isnan(medians))[0]
        median_values = {median_var.iloc[k]: medians[k].item() for k in valid_idxs}

        save_path = data_path.replace(hyperparameters["load_dir"], hyperparameters["save_dir"])
        save_path = save_path.replace(".h5ad", "_medians.json")
        with open(save_path, "w") as f:
            json.dump(median_values, f, indent=4)
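
# Worked example of the zero-masking above: for a gene whose processed column
# is [0, 2, 4], the zero becomes NaN and the stored median is 3.0, not 2.0.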
|
|
|
|
def update_metadata(metadata: dict, data: ad.AnnData, hyperparameters: dict):
    """
    Update metadata with cell_count and append this run's hyperparameters to
    the processing_args history.
    """
    metadata["cell_count"] = data.n_obs
    if "processing_args" in metadata:
        previous = metadata["processing_args"]
        if not isinstance(previous, list):
            previous = [previous]
        metadata["processing_args"] = previous + [hyperparameters]
    else:
        metadata["processing_args"] = [hyperparameters]
    return metadata
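
# Minimal usage sketch (toy metadata dict): repeated runs accumulate their
# hyperparameters in order.
def _demo_update_metadata():
    toy = ad.AnnData(X=np.zeros((3, 2), dtype=np.float32))
    meta = update_metadata({}, toy, {"run": 1})
    meta = update_metadata(meta, toy, {"run": 2})
    assert meta["cell_count"] == 3
    assert meta["processing_args"] == [{"run": 1}, {"run": 2}]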
|
|
|
|
def save_and_cleanup(data: ad.AnnData, metadata: dict, data_path: str, metadata_path: str, hyperparameters: dict):
    """
    Write the processed data and metadata to disk, then run GC cleanup.
    Returns (None, None) if the filtered data contains no cells.
    """
    save_dir = hyperparameters["save_dir"]
    data_filename = os.path.basename(data_path)
    metadata_filename = os.path.basename(metadata_path)

    save_processed_path = os.path.join(save_dir, data_filename)
    save_metadata_path = os.path.join(save_dir, metadata_filename)

    os.makedirs(os.path.dirname(save_processed_path), exist_ok=True)
    os.makedirs(os.path.dirname(save_metadata_path), exist_ok=True)

    if data.n_obs == 0:
        return None, None

    # Store all matrices as CSR for compact, consistent h5ad output.
    if not isinstance(data.raw.X, csr_matrix):
        data.raw.X = csr_matrix(data.raw.X)
    if not isinstance(data.X, csr_matrix):
        data.X = csr_matrix(data.X)
    if "processed" in data.layers and not isinstance(data.layers["processed"], csr_matrix):
        data.layers["processed"] = csr_matrix(data.layers["processed"])

    try:
        data.write_h5ad(save_processed_path, compression="gzip")
    except Exception:
        # Writing can fail when the obs index name collides with an obs
        # column; drop the duplicated column and retry.
        if data.obs.index.name in data.obs.columns:
            del data.obs[data.obs.index.name]
        data.write_h5ad(save_processed_path, compression="gzip")

    del data
    gc.collect()

    with open(save_metadata_path, "w") as f:
        json.dump(metadata, f, indent=4)

    return True, True
|
|
|
|
def preprocess(data_path: str, metadata_path: str, hyperparameters: dict):
    """
    Pipeline steps:
    1. Load data & metadata
    2. Ensure data.raw is set if counts are integer
    3. Initialize 'processed' layer
    4. Filter genes by reference_id
    5. Remove assays
    6. Filter cells (min gene counts)
    7. Filter cells (max mito fraction)
    8. HVG filtering
    9. Normalize total
    10. Median-based column scaling
    11. Log transform
    12. Compute medians (optional)
    13. Update metadata and save
    """
    data, metadata = load_data_and_metadata(data_path, metadata_path)

    data = set_raw_if_necessary(data)
    if data is None:
        return None, None

    data = initialize_processed_layer(data)

    if hyperparameters["reference_id_only"]:
        data = filter_reference_id(data, hyperparameters)

    if "assay" in data.obs and hyperparameters["remove_assays"]:
        data = remove_assays(data, hyperparameters["remove_assays"])

    if hyperparameters["min_gene_counts"]:
        data = filter_cells_by_gene_counts(data, hyperparameters["min_gene_counts"])

    if hyperparameters["max_mitochondrial_prop"]:
        data = filter_cells_by_mitochondrial_fraction(
            data, hyperparameters["max_mitochondrial_prop"])

    if hyperparameters["hvg_method"] in ["seurat_v3", "cell_ranger"]:
        data = filter_highly_variable_genes(data, hyperparameters["hvg_method"])

    if hyperparameters["normalized_total"]:
        if not isinstance(data.layers["processed"], csr_matrix):
            data.layers["processed"] = csr_matrix(data.layers["processed"])
        normalize_data_inplace(data.layers["processed"], hyperparameters["normalized_total"])

    if hyperparameters["median_dict"]:
        scale_columns_by_median_dict(
            data.layers["processed"], data, hyperparameters["median_dict"], hyperparameters["median_column"]
        )

    if hyperparameters["log1p"]:
        log_transform_layer(data, "processed")

    if hyperparameters["compute_medians"]:
        compute_and_save_medians(data, data_path, hyperparameters)

    metadata = update_metadata(metadata, data, hyperparameters)
    return save_and_cleanup(data, metadata, data_path, metadata_path, hyperparameters)
|
|
|
|
| if __name__ == "__main__": |
| parser = ArgumentParser(description="Preprocess scRNA-seq data stored in AnnData format.") |
| parser.add_argument( |
| "--data_path", |
| type=str, |
| required=True, |
| help="Path to the input .h5ad file." |
| ) |
| parser.add_argument( |
| "--metadata_path", |
| type=str, |
| required=True, |
| help="Path to the input metadata JSON file." |
| ) |
| parser.add_argument( |
| "--config_path", |
| type=str, |
| required=True, |
| help="Path to the JSON configuration file containing hyperparameters." |
| ) |
|
|
| args = parser.parse_args() |
|
|
| |
| with open(args.config_path, "r") as f: |
| hyperparameters = json.load(f) |
|
|
| |
| success, _ = preprocess( |
| data_path=args.data_path, |
| metadata_path=args.metadata_path, |
| hyperparameters=hyperparameters |
| ) |
|
|
| if success: |
| print("Preprocessing completed successfully.") |
| else: |
| print("Preprocessing returned no data (0 cells), no file saved.") |
|
|