| import io |
| from itertools import accumulate, chain |
| from copy import deepcopy |
| import random |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from rdkit import Chem |
| from torch_scatter import scatter_mean |
| from Bio.PDB import StructureBuilder, Chain, Model, Structure |
| from Bio.PDB.PICIO import read_PIC, write_PIC |
| from scipy.ndimage import gaussian_filter |
| from pdb import set_trace |
|
|
| from src.constants import FLOAT_TYPE, INT_TYPE |
| from src.constants import atom_encoder, bond_encoder, aa_encoder, residue_encoder, residue_bond_encoder, aa_atom_index |
| from src import utils |
| from src.data.misc import protein_letters_3to1, is_aa |
| from src.data.normal_modes import pdb_to_normal_modes |
| from src.data.nerf import get_nerf_params, ic_to_coords |
| import src.data.so3_utils as so3 |
|
|
|
|
class TensorDict(dict):
    """dict subclass whose tensor values can be moved between devices or
    detached in bulk, in place."""

    def __init__(self, **kwargs):
        super(TensorDict, self).__init__(**kwargs)

    def _apply(self, func: str, *args, **kwargs):
        """Call the tensor method named `func` on every tensor value, replacing
        each value in place; non-tensor values are left untouched."""
        for key in self:
            value = self[key]
            if torch.is_tensor(value):
                self[key] = getattr(value, func)(*args, **kwargs)
        return self

    def cuda(self):
        return self.to('cuda')

    def cpu(self):
        return self.to('cpu')

    def to(self, device):
        return self._apply("to", device)

    def detach(self):
        return self._apply("detach")

    def __repr__(self):
        def describe(value):
            # tensors are shown by shape, lists by length, everything else as '?'
            if isinstance(value, torch.Tensor):
                return repr(list(value.size()))
            if isinstance(value, list):
                return "[%r,]" % len(value)
            return "?"

        body = ", ".join(f"{key}={describe(value)}" for key, value in self.items())
        return f"{type(self).__name__}({body})"
|
|
|
|
def collate_entity(batch):
    """Merge a list of per-item dicts into one batched dict.

    Names stay a list, sizes become a 1D tensor, bond indices are shifted by
    the cumulative node count, and 'mask' / 'bond_mask' are recomputed from
    'x' / 'bond_one_hot' respectively.
    """
    out = {}
    for key in batch[0].keys():

        if key == 'name':
            out[key] = [item[key] for item in batch]

        elif key in ('size', 'n_bonds'):
            out[key] = torch.tensor([item[key] for item in batch])

        elif key == 'bonds':
            # shift node indices of each item by the number of nodes before it
            offsets = list(accumulate((item['size'] for item in batch), initial=0))
            shifted = [item[key] + offsets[i] for i, item in enumerate(batch)]
            out[key] = torch.cat(shifted, dim=1)

        elif key == 'residues':
            out[key] = list(chain.from_iterable(item[key] for item in batch))

        elif key in {'mask', 'bond_mask'}:
            # ignored: rebuilt below from 'x' / 'bond_one_hot'
            pass

        else:
            out[key] = torch.cat([item[key] for item in batch], dim=0)

        # batch-index masks, one entry per node / per bond
        if key == 'x':
            out['mask'] = torch.cat(
                [torch.full((len(item[key]),), i, dtype=torch.int64, device=item[key].device)
                 for i, item in enumerate(batch)], dim=0)
        if key == 'bond_one_hot':
            out['bond_mask'] = torch.cat(
                [torch.full((len(item[key]),), i, dtype=torch.int64, device=item[key].device)
                 for i, item in enumerate(batch)], dim=0)

    return out
|
|
|
|
def split_entity(
        batch,
        *,
        index_types={'bonds'},
        edge_types={'bond_one_hot', 'bond_mask'},
        no_split={'name', 'size', 'n_bonds'},
        skip={'fragments'},
        batch_mask=None,
        edge_mask=None
):
    """ Splits a batch into items and returns a list. """

    if batch_mask is None:
        batch_mask = batch["mask"]
    if edge_mask is None:
        edge_mask = batch["bond_mask"]

    # per-item node counts, either stored or derived from the mask
    if 'size' in batch:
        sizes = batch['size']
    else:
        sizes = torch.unique(batch_mask, return_counts=True)[1].tolist()

    batch_size = len(torch.unique(batch['mask']))

    per_key = {}
    for key, value in batch.items():
        if key in skip:
            continue
        if key in no_split:
            # already one entry per item
            per_key[key] = value
        elif key in index_types:
            # index tensors must be shifted back to item-local numbering
            offsets = list(accumulate(sizes[:-1], initial=0))
            per_key[key] = utils.batch_to_list_for_indices(value, edge_mask, offsets)
        elif key in edge_types:
            per_key[key] = utils.batch_to_list(value, edge_mask)
        else:
            per_key[key] = utils.batch_to_list(value, batch_mask)

    return [{k: v[i] for k, v in per_key.items()} for i in range(batch_size)]
|
|
|
|
def repeat_items(batch, repeats):
    """Tile the batch `repeats` times and collate back into one batch object
    of the same type as the input."""
    items = split_entity(batch)
    tiled = []
    for _ in range(repeats):
        tiled.extend(items)
    return type(batch)(**collate_entity(tiled))
|
|
|
|
def get_side_chain_bead_coord(biopython_residue):
    """
    Places side chain bead at the location of the farthest side chain atom.

    Returns None for glycine (no side chain) and the CB position for alanine.
    """
    resname = biopython_residue.get_resname()
    if resname == 'GLY':
        return None
    if resname == 'ALA':
        return biopython_residue['CB'].get_coord()

    ca_coord = biopython_residue['CA'].get_coord()
    # heavy side-chain atoms only: drop backbone atoms and hydrogens
    heavy_side_chain = [
        atom for atom in biopython_residue.get_atoms()
        if atom.id not in {'N', 'CA', 'C', 'O'} and atom.element != 'H'
    ]
    coords = np.stack([atom.get_coord() for atom in heavy_side_chain])

    # squared distance to CA is enough to find the farthest atom
    sq_dist = np.sum((coords - ca_coord[None, :]) ** 2, axis=-1)
    return coords[np.argmax(sq_dist), :]
|
|
|
|
def get_side_chain_vectors(res, index_dict, size=None):
    """Return (size, 3) displacement vectors (atom - CA) for the atoms listed
    in `index_dict` for this residue type; unlisted slots stay zero."""
    if size is None:
        # smallest array that fits the largest index in the mapping
        size = max(x for aa in index_dict.values() for x in aa.values()) + 1

    resname = protein_letters_3to1[res.get_resname()]
    atom_index = index_dict[resname]
    ca_coord = res['CA'].get_coord()

    out = np.zeros((size, 3))
    for atom in res.get_atoms():
        name = atom.get_name()
        if name in atom_index:
            out[atom_index[name]] = atom.get_coord() - ca_coord

    return out
|
|
|
|
def get_normal_modes(res, normal_mode_dict):
    """Look up the normal-mode entry for this residue's CA, keyed by
    (chain id, residue sequence number, 'CA')."""
    key = (res.get_parent().id, res.id[1], 'CA')
    return normal_mode_dict[key]
|
|
|
|
def get_torsion_angles(res, device=None):
    """
    Return the five chi angles in radians as a tensor of length 5.
    Angles the residue does not define are filled with NaN.
    """
    CHI_NAMES = ('chi1', 'chi2', 'chi3', 'chi4', 'chi5')

    ic_res = res.internal_coord
    angles_deg = []
    for name in CHI_NAMES:
        angle = ic_res.get_angle(name)
        angles_deg.append(float('nan') if angle is None else angle)

    # degrees -> radians
    return torch.tensor(angles_deg, device=device) * np.pi / 180
|
|
|
|
def apply_torsion_angles(res, chi_angles):
    """
    Set side chain torsion angles of a biopython residue object with
    internal coordinates.

    `chi_angles` is given in radians; angles the residue type does not
    define are skipped. Returns the same residue with updated atom coords.
    """
    CHI_NAMES = ('chi1', 'chi2', 'chi3', 'chi4', 'chi5')

    # radians -> degrees, as expected by the internal-coordinate API
    angles_deg = chi_angles * 180 / np.pi

    ic_res = res.internal_coord
    for name, angle in zip(CHI_NAMES, angles_deg):
        # skip angles this residue type does not have
        if ic_res.pick_angle(name) is not None:
            ic_res.bond_set(name, angle)

    # propagate the updated internal coordinates back to Cartesian atoms
    res.parent.internal_to_atom_coordinates(verbose=False)

    return res
|
|
|
|
def prepare_internal_coord(res):
    """Compute internal coordinates for a single residue and serialize them
    as a PIC-format string."""
    # wrap the residue in a minimal structure/model/chain hierarchy, since
    # internal-coordinate computation and PIC writing operate on structures
    struct = Structure.Structure('X')
    struct.header = {}
    model = Model.Model(0)
    struct.add(model)
    pdb_chain = Chain.Chain('X')
    model.add(pdb_chain)
    pdb_chain.add(res)
    res.set_parent(pdb_chain)

    pdb_chain.atom_to_internal_coordinates()

    buffer = io.StringIO()
    write_PIC(struct, buffer)
    return buffer.getvalue()
|
|
|
|
def residue_from_internal_coord(ic_string):
    """Rebuild a biopython residue (with Cartesian atom coordinates) from a
    PIC-format string produced by prepare_internal_coord."""
    struct = read_PIC(io.StringIO(ic_string), quick=True)
    # unwrap structure -> model -> chain -> residue
    res = struct.child_list[0].child_list[0].child_list[0]
    res.parent.internal_to_atom_coordinates(verbose=False)
    return res
|
|
|
|
def prepare_pocket(biopython_residues, amino_acid_encoder, residue_encoder,
                   residue_bond_encoder, pocket_representation='side_chain_bead',
                   compute_nerf_params=False, compute_bb_frames=False,
                   nma_input=None):
    """Convert biopython residues into a tensor-dict pocket representation.

    Args:
        biopython_residues: pocket residues in any order (sorted here by
            chain id and residue number).
        amino_acid_encoder: maps 1-letter amino acid codes to indices.
        residue_encoder: maps bead types ('CA', 'SS') to indices.
        residue_bond_encoder: maps bond types ('CA-CA', 'CA-SS') to indices.
        pocket_representation: 'side_chain_bead' (one CA bead plus one bead at
            the farthest side-chain atom per residue) or 'CA+' (one CA node per
            residue with per-atom side-chain vector features).
        compute_nerf_params: if True, add stacked per-residue NeRF parameters
            from get_nerf_params to the output dict.
        compute_bb_frames: if True, add backbone-frame rotations ('axis_angle')
            from N/CA/C coordinates via get_bb_transform.
        nma_input: optional normal-mode data; a precomputed dict or an input
            for pdb_to_normal_modes. Only supported for 'CA+'.

    Returns:
        (pocket dict, sorted biopython_residues)
    """
    assert nma_input is None or pocket_representation == 'CA+', \
        "vector features are only supported for CA+ pockets"

    # canonical ordering: by chain id, then residue sequence number
    biopython_residues = sorted(biopython_residues, key=lambda x: (x.parent.id, x.id[1]))

    if nma_input is not None:
        # normal modes can be given precomputed or derived from a PDB path
        if isinstance(nma_input, dict):
            nma_dict = nma_input
        else:
            nma_dict = pdb_to_normal_modes(str(nma_input))

    if pocket_representation == 'side_chain_bead':
        ca_coords = np.zeros((len(biopython_residues), 3))
        ca_types = np.zeros(len(biopython_residues), dtype='int64')
        side_chain_coords = []
        side_chain_aa_types = []
        edges = []
        edge_types = []
        last_res_id = None
        for i, res in enumerate(biopython_residues):
            aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
            ca_coords[i, :] = res['CA'].get_coord()
            ca_types[i] = aa
            side_chain_coord = get_side_chain_bead_coord(res)
            if side_chain_coord is not None:  # None only for glycine
                side_chain_coords.append(side_chain_coord)
                side_chain_aa_types.append(aa)
                # side-chain beads are appended after all CA nodes, hence
                # the len(ca_coords) offset in the node index
                edges.append((i, len(ca_coords) + len(side_chain_coords) - 1))
                edge_types.append(residue_bond_encoder['CA-SS'])

            # connect consecutive residues; NOTE(review): this only compares
            # sequence numbers, not chain ids — confirm residues from
            # different chains cannot have consecutive numbering here
            if i > 0 and res.id[1] == last_res_id + 1:
                edges.append((i - 1, i))
                edge_types.append(residue_bond_encoder['CA-CA'])

            last_res_id = res.id[1]

        # node order: all CA beads first, then all side-chain beads
        side_chain_coords = np.stack(side_chain_coords)
        pocket_coords = np.concatenate([ca_coords, side_chain_coords], axis=0)
        pocket_coords = torch.from_numpy(pocket_coords)

        # amino-acid identity one-hot for CA beads and side-chain beads
        amino_acid_onehot = F.one_hot(
            torch.cat([torch.from_numpy(ca_types), torch.tensor(side_chain_aa_types, dtype=torch.int64)], dim=0),
            num_classes=len(amino_acid_encoder)
        )
        # bead-type one-hot: 'CA' rows followed by 'SS' rows
        side_chain_onehot = np.concatenate([
            np.tile(np.eye(1, len(residue_encoder), residue_encoder['CA']),
                    [len(ca_coords), 1]),
            np.tile(np.eye(1, len(residue_encoder), residue_encoder['SS']),
                    [len(side_chain_coords), 1])
        ], axis=0)
        side_chain_onehot = torch.from_numpy(side_chain_onehot)
        # final node features: [amino acid one-hot | bead type one-hot]
        pocket_onehot = torch.cat([amino_acid_onehot, side_chain_onehot], dim=1)

        # this representation carries no vector features
        vector_features = None
        nma_features = None

        edges = torch.tensor(edges).T
        edge_types = F.one_hot(torch.tensor(edge_types), num_classes=len(residue_bond_encoder))

    elif pocket_representation == 'CA+':
        ca_coords = np.zeros((len(biopython_residues), 3))
        ca_types = np.zeros(len(biopython_residues), dtype='int64')

        # one vector slot per atom name indexed in aa_atom_index
        v_dim = max([x for aa in aa_atom_index.values() for x in aa.values()]) + 1
        vec_feats = np.zeros((len(biopython_residues), v_dim, 3), dtype='float32')
        # number of normal modes per residue (matches pdb_to_normal_modes
        # output — TODO confirm)
        nf_nma = 5
        nma_feats = np.zeros((len(biopython_residues), nf_nma, 3), dtype='float32')

        edges = []
        edge_types = []
        last_res_id = None
        for i, res in enumerate(biopython_residues):
            aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
            ca_coords[i, :] = res['CA'].get_coord()
            ca_types[i] = aa

            # displacement vectors from CA to each indexed side-chain atom
            vec_feats[i] = get_side_chain_vectors(res, aa_atom_index, v_dim)
            if nma_input is not None:
                nma_feats[i] = get_normal_modes(res, nma_dict)

            # connect consecutive residues (same caveat as above on chain ids)
            if i > 0 and res.id[1] == last_res_id + 1:
                edges.append((i - 1, i))
                edge_types.append(residue_bond_encoder['CA-CA'])

            last_res_id = res.id[1]

        pocket_coords = torch.from_numpy(ca_coords)

        pocket_onehot = F.one_hot(torch.from_numpy(ca_types),
                                  num_classes=len(amino_acid_encoder))

        vector_features = torch.from_numpy(vec_feats)
        nma_features = torch.from_numpy(nma_feats)

        # a single-residue pocket has no CA-CA edges
        if len(edges) < 1:
            edges = torch.empty(2, 0)
            edge_types = torch.empty(0, len(residue_bond_encoder))
        else:
            edges = torch.tensor(edges).T
            edge_types = F.one_hot(torch.tensor(edge_types),
                                   num_classes=len(residue_bond_encoder))

    else:
        raise NotImplementedError(
            f"Pocket representation '{pocket_representation}' not implemented")

    pocket = {
        'x': pocket_coords.to(dtype=FLOAT_TYPE),
        'one_hot': pocket_onehot.to(dtype=FLOAT_TYPE),
        # single-item batch: size is a 1-element tensor, masks are all zeros
        'size': torch.tensor([len(pocket_coords)], dtype=INT_TYPE),
        'mask': torch.zeros(len(pocket_coords), dtype=INT_TYPE),
        'bonds': edges.to(INT_TYPE),
        'bond_one_hot': edge_types.to(FLOAT_TYPE),
        'bond_mask': torch.zeros(edges.size(1), dtype=INT_TYPE),
        'n_bonds': torch.tensor([len(edge_types)], dtype=INT_TYPE),
    }

    if vector_features is not None:
        pocket['v'] = vector_features.to(dtype=FLOAT_TYPE)

    if nma_input is not None:
        pocket['nma_vec'] = nma_features.to(dtype=FLOAT_TYPE)

    if compute_nerf_params:
        # stack per-residue NeRF parameter dicts into per-key tensors
        nerf_params = [get_nerf_params(r) for r in biopython_residues]
        nerf_params = {k: torch.stack([x[k] for x in nerf_params], dim=0)
                       for k in nerf_params[0].keys()}
        pocket.update(nerf_params)

    if compute_bb_frames:
        n_xyz = torch.from_numpy(np.stack([r['N'].get_coord() for r in biopython_residues]))
        ca_xyz = torch.from_numpy(np.stack([r['CA'].get_coord() for r in biopython_residues]))
        c_xyz = torch.from_numpy(np.stack([r['C'].get_coord() for r in biopython_residues]))
        pocket['axis_angle'], _ = get_bb_transform(n_xyz, ca_xyz, c_xyz)

    return pocket, biopython_residues
|
|
|
|
def encode_atom(rd_atom, atom_encoder):
    """Map an RDKit atom to its type index.

    Prefers an explicit-hydrogen token (e.g. 'NH') when the atom carries one
    explicit H, then a charge token (e.g. 'N+'/'O-'), falling back to the
    plain element symbol.
    """
    element = rd_atom.GetSymbol().capitalize()

    # explicit-hydrogen variant, e.g. 'NH'
    if rd_atom.GetNumExplicitHs() == 1 and f'{element}H' in atom_encoder:
        return atom_encoder[f'{element}H']

    # charged variants, e.g. 'N+' / 'O-'
    charge = rd_atom.GetFormalCharge()
    for value, suffix in ((1, '+'), (-1, '-')):
        if charge == value and f'{element}{suffix}' in atom_encoder:
            return atom_encoder[f'{element}{suffix}']

    return atom_encoder[element]
|
|
|
|
def prepare_ligand(rdmol, atom_encoder, bond_encoder):
    """Convert an RDKit molecule into the tensor-dict ligand representation:
    coordinates, one-hot atom types, and the full upper-triangle bond list
    (absent pairs encoded as 'NOBOND')."""

    # strip hydrogens unless the encoder models them explicitly
    if 'H' not in atom_encoder:
        rdmol = Chem.RemoveAllHs(rdmol, sanitize=False)

    ligand_coord = torch.from_numpy(rdmol.GetConformer().GetPositions())

    atom_types = [encode_atom(a, atom_encoder) for a in rdmol.GetAtoms()]
    ligand_onehot = F.one_hot(torch.tensor(atom_types),
                              num_classes=len(atom_encoder))

    # dense symmetric adjacency, defaulting every pair to 'NOBOND'
    n_atoms = rdmol.GetNumAtoms()
    adj = np.full((n_atoms, n_atoms), bond_encoder['NOBOND'], dtype=float)
    for bond in rdmol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        adj[i, j] = adj[j, i] = bond_encoder[str(bond.GetBondType())]

    # enumerate every pair (i < j) once
    bonds = np.stack(np.triu_indices(len(ligand_coord), k=1), axis=0)
    bond_types = adj[bonds[0], bonds[1]].astype('int64')
    bonds = torch.from_numpy(bonds)
    bond_onehot = F.one_hot(torch.from_numpy(bond_types),
                            num_classes=len(bond_encoder))

    return {
        'x': ligand_coord.to(dtype=FLOAT_TYPE),
        'one_hot': ligand_onehot.to(dtype=FLOAT_TYPE),
        'mask': torch.zeros(len(ligand_coord), dtype=INT_TYPE),
        'bonds': bonds.to(INT_TYPE),
        'bond_one_hot': bond_onehot.to(FLOAT_TYPE),
        'bond_mask': torch.zeros(bonds.size(1), dtype=INT_TYPE),
        'size': torch.tensor([len(ligand_coord)], dtype=INT_TYPE),
        'n_bonds': torch.tensor([len(bond_onehot)], dtype=INT_TYPE),
    }
|
|
|
|
def process_raw_molecule_with_empty_pocket(rdmol):
    """Prepare a ligand from an RDKit molecule and pair it with an all-empty
    pocket placeholder (every field an empty tensor)."""
    ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)

    float_keys = ('x', 'one_hot', 'bond_one_hot')
    int_keys = ('size', 'mask', 'bonds', 'bond_mask', 'n_bonds')
    pocket = {}
    for key in float_keys:
        pocket[key] = torch.tensor([], dtype=FLOAT_TYPE)
    for key in int_keys:
        pocket[key] = torch.tensor([], dtype=INT_TYPE)

    return ligand, pocket
|
|
|
|
def process_raw_pair(biopython_model, rdmol, dist_cutoff=None,
                     pocket_representation='side_chain_bead',
                     compute_nerf_params=False, compute_bb_frames=False,
                     nma_input=None, return_pocket_pdb=False):
    """Turn a biopython protein model plus an RDKit ligand into a
    (ligand, pocket) pair of tensor dicts.

    Args:
        biopython_model: protein model scanned for pocket residues.
        rdmol: RDKit molecule with a 3D conformer.
        dist_cutoff: if given, keep only standard residues whose minimum
            atom-atom distance to the ligand is below this value; otherwise
            keep all standard amino acids.
        pocket_representation, compute_nerf_params, compute_bb_frames,
            nma_input: forwarded to prepare_pocket.
        return_pocket_pdb: if True, attach a biopython Structure of the
            selected residues as pocket['pocket_pdb'].

    Returns:
        (ligand dict, pocket dict)
    """
    ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)

    # select pocket residues: standard amino acids within the cutoff
    pocket_residues = []
    for residue in biopython_model.get_residues():

        # skip water, hetero residues and non-standard amino acids
        if not is_aa(residue.get_resname(), standard=True):
            continue

        res_coords = torch.from_numpy(np.array([a.get_coord() for a in residue.get_atoms()]))
        # minimum pairwise Euclidean distance between residue and ligand atoms
        if dist_cutoff is None or (((res_coords[:, None, :] - ligand['x'][None, :, :]) ** 2).sum(-1) ** 0.5).min() < dist_cutoff:
            pocket_residues.append(residue)

    pocket, pocket_residues = prepare_pocket(
        pocket_residues, aa_encoder, residue_encoder, residue_bond_encoder,
        pocket_representation, compute_nerf_params, compute_bb_frames, nma_input
    )

    if return_pocket_pdb:
        # assemble a minimal Structure holding only the pocket residues
        builder = StructureBuilder.StructureBuilder()
        builder.init_structure("")
        builder.init_model(0)
        pocket_struct = builder.get_structure()
        for residue in pocket_residues:
            chain = residue.get_parent().get_id()

            # create each chain lazily on first use
            if not pocket_struct[0].has_id(chain):
                builder.init_chain(chain)

            pocket_struct[0][chain].add(residue)

        pocket['pocket_pdb'] = pocket_struct

    return ligand, pocket
|
|
|
|
class AppendVirtualNodes:
    """Pad a ligand with 'NOATOM' virtual nodes up to a fixed size.

    Virtual coordinates are sampled from a Gaussian fitted to the real atoms
    (mean plus Cholesky factor of the covariance, scaled by `scale`), and the
    bond list is rebuilt as the full upper triangle, with 'NOBOND' for all
    pairs involving a virtual node.
    """

    def __init__(self, atom_encoder, bond_encoder, max_ligand_size, scale=1.0):
        self.max_size = max_ligand_size
        self.atom_encoder = atom_encoder
        self.bond_encoder = bond_encoder
        self.vidx = atom_encoder['NOATOM']
        self.bidx = bond_encoder['NOBOND']
        self.scale = scale

    def __call__(self, ligand, max_size=None, eps=1e-6):
        if max_size is None:
            max_size = self.max_size

        # NOTE(review): assumes ligand['size'] behaves like an int here —
        # confirm callers pass a scalar, not a 1-element tensor
        n_virt = max_size - ligand['size']

        # sample virtual coordinates from a Gaussian fit to the real atoms;
        # eps regularizes the covariance before the Cholesky factorization
        cov = torch.cov(ligand['x'].T)
        chol = torch.linalg.cholesky(cov + torch.eye(3) * eps)
        center = ligand['x'].mean(0, keepdim=True)
        virt_coords = center + torch.randn(n_virt, 3) @ chol.T * self.scale

        virt_one_hot = F.one_hot(
            torch.ones(n_virt, dtype=torch.int64) * self.vidx,
            num_classes=len(self.atom_encoder))
        virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool),
                               torch.ones(n_virt, dtype=bool)])

        ligand['x'] = torch.cat([ligand['x'], virt_coords])
        ligand['one_hot'] = torch.cat([ligand['one_hot'], virt_one_hot])
        ligand['virtual_mask'] = virt_mask
        ligand['size'] = max_size

        # rebuild the complete upper-triangle bond list
        new_bonds = torch.triu_indices(max_size, max_size, offset=1)
        dense = torch.ones(max_size, max_size, dtype=INT_TYPE) * self.bidx
        row, col = ligand['bonds']
        dense[row, col] = ligand['bond_one_hot'].argmax(dim=1)
        bond_types = dense[new_bonds[0], new_bonds[1]]

        ligand['bonds'] = new_bonds
        ligand['bond_one_hot'] = F.one_hot(
            bond_types, num_classes=len(self.bond_encoder)
        ).to(ligand['bond_one_hot'].dtype)
        ligand['n_bonds'] = len(ligand['bond_one_hot'])

        return ligand
|
|
|
|
class AppendVirtualNodesInCoM:
    """Append a random number of 'NOATOM' virtual nodes, all placed exactly at
    the ligand's center of mass, and rebuild the full upper-triangle bond list
    with 'NOBOND' for all pairs involving a virtual node."""

    def __init__(self, atom_encoder, bond_encoder, add_min=0, add_max=10):
        self.atom_encoder = atom_encoder
        self.bond_encoder = bond_encoder
        self.vidx = atom_encoder['NOATOM']
        self.bidx = bond_encoder['NOBOND']
        self.add_min = add_min
        self.add_max = add_max

    def __call__(self, ligand):
        # number of virtual nodes is sampled uniformly in [add_min, add_max]
        n_virt = random.randint(self.add_min, self.add_max)

        # every virtual node sits at the center of mass of the real atoms
        com = ligand['x'].mean(0, keepdim=True)
        virt_coords = com.repeat(n_virt, 1)

        virt_one_hot = F.one_hot(
            torch.ones(n_virt, dtype=torch.int64) * self.vidx,
            num_classes=len(self.atom_encoder))
        virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool),
                               torch.ones(n_virt, dtype=bool)])

        ligand['x'] = torch.cat([ligand['x'], virt_coords])
        ligand['one_hot'] = torch.cat([ligand['one_hot'], virt_one_hot])
        ligand['virtual_mask'] = virt_mask
        ligand['size'] = len(ligand['x'])

        # rebuild the complete upper-triangle bond list
        size = ligand['size']
        new_bonds = torch.triu_indices(size, size, offset=1)
        dense = torch.ones(size, size, dtype=INT_TYPE) * self.bidx
        row, col = ligand['bonds']
        dense[row, col] = ligand['bond_one_hot'].argmax(dim=1)
        bond_types = dense[new_bonds[0], new_bonds[1]]

        ligand['bonds'] = new_bonds
        ligand['bond_one_hot'] = F.one_hot(
            bond_types, num_classes=len(self.bond_encoder)
        ).to(ligand['bond_one_hot'].dtype)
        ligand['n_bonds'] = len(ligand['bond_one_hot'])

        return ligand
|
|
|
|
def rdmol_to_smiles(rdmol):
    """Canonical SMILES of the molecule, without stereochemistry or
    explicit hydrogens."""
    mol = Chem.Mol(rdmol)  # work on a copy; the input is left untouched
    Chem.RemoveStereochemistry(mol)
    return Chem.MolToSmiles(Chem.RemoveHs(mol))
|
|
|
|
| def get_n_nodes(lig_positions, pocket_positions, smooth_sigma=None): |
| |
| n_nodes_lig = [len(x) for x in lig_positions] |
| n_nodes_pocket = [len(x) for x in pocket_positions] |
|
|
| joint_histogram = np.zeros((np.max(n_nodes_lig) + 1, |
| np.max(n_nodes_pocket) + 1)) |
|
|
| for nlig, npocket in zip(n_nodes_lig, n_nodes_pocket): |
| joint_histogram[nlig, npocket] += 1 |
|
|
| print(f'Original histogram: {np.count_nonzero(joint_histogram)}/' |
| f'{joint_histogram.shape[0] * joint_histogram.shape[1]} bins filled') |
|
|
| |
| if smooth_sigma is not None: |
| filtered_histogram = gaussian_filter( |
| joint_histogram, sigma=smooth_sigma, order=0, mode='constant', |
| cval=0.0, truncate=4.0) |
|
|
| print(f'Smoothed histogram: {np.count_nonzero(filtered_histogram)}/' |
| f'{filtered_histogram.shape[0] * filtered_histogram.shape[1]} bins filled') |
|
|
| joint_histogram = filtered_histogram |
|
|
| return joint_histogram |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def get_type_histogram(one_hot, type_encoder):
    """Count occurrences of each type (via argmax over the one-hot axis)
    across a list of one-hot arrays; absent types get count 0."""
    stacked = np.concatenate(one_hot, axis=0)

    # encoder insertion order defines the index -> name mapping
    decoder = list(type_encoder)
    counts = dict.fromkeys(type_encoder, 0)
    for idx in stacked.argmax(1):
        counts[decoder[idx]] += 1

    return counts
|
|
|
|
def get_residue_with_resi(pdb_chain, resi):
    """Return the unique residue in the chain with sequence number `resi`."""
    matches = [r for r in pdb_chain.get_residues() if r.id[1] == resi]
    assert len(matches) == 1
    return matches[0]
|
|
|
|
def get_pocket_from_ligand(pdb_model, ligand, dist_cutoff=8.0):
    """Collect the standard amino-acid residues within `dist_cutoff` of a ligand.

    Args:
        pdb_model: biopython model to search for pocket residues.
        ligand: either a path ending in '.sdf' (ligand read with RDKit), or a
            '<chain>:<resi>' reference to a ligand residue inside `pdb_model`.
        dist_cutoff: inclusion radius (minimum atom-atom distance).

    Returns:
        list of biopython residues forming the pocket.
    """
    if ligand.endswith(".sdf"):
        # ligand is provided as a separate SDF file
        rdmol = Chem.SDMolSupplier(str(ligand))[0]
        ligand_coords = torch.from_numpy(rdmol.GetConformer().GetPositions()).float()
        resi = None
    else:
        # ligand is a residue inside the PDB model, referenced as '<chain>:<resi>'
        chain_id, resi = ligand.split(':')
        # BUGFIX: residue sequence numbers are ints in Bio.PDB; without this
        # conversion `residue.id[1] == resi` below compared int to str and
        # never excluded the ligand residue itself
        resi = int(resi)
        ligand_res = get_residue_with_resi(pdb_model[chain_id], resi)
        ligand_coords = torch.from_numpy(
            np.array([a.get_coord() for a in ligand_res.get_atoms()]))

    pocket_residues = []
    for residue in pdb_model.get_residues():
        # exclude the ligand residue itself
        if residue.id[1] == resi:
            continue

        res_coords = torch.from_numpy(
            np.array([a.get_coord() for a in residue.get_atoms()]))
        if is_aa(residue.get_resname(), standard=True) \
                and torch.cdist(res_coords, ligand_coords).min() < dist_cutoff:
            pocket_residues.append(residue)

    return pocket_residues
|
|
|
|
def encode_residues(biopython_residues, type_encoder, level='atom',
                    remove_H=True):
    """Encode residues as (coords, one-hot types) at atom or residue level.

    At 'atom' level each (heavy, unless remove_H=False) atom becomes one
    entry typed by element; at 'residue' level each residue becomes one entry
    at its CA position, typed by 1-letter amino acid code.
    """
    assert level in {'atom', 'residue'}

    if level == 'atom':
        entities = []
        for res in biopython_residues:
            for atom in res.get_atoms():
                if remove_H and atom.element == 'H':
                    continue
                entities.append(atom)
        types = [atom.element.capitalize() for atom in entities]
    else:
        entities = [res['CA'] for res in biopython_residues]
        types = [protein_letters_3to1[res.get_resname()] for res in biopython_residues]

    coord = torch.tensor(np.stack([e.get_coord() for e in entities]))
    one_hot = F.one_hot(torch.tensor([type_encoder[t] for t in types]),
                        num_classes=len(type_encoder))

    return coord, one_hot
|
|
|
|
def center_data(ligand, pocket):
    """Center each sample on its pocket center of mass; when the pocket is
    empty, fall back to the per-sample ligand center of mass instead."""
    if pocket['x'].numel() > 0:
        # Residues.center() shifts the pocket in place and returns the COM
        com = pocket.center()
    else:
        com = scatter_mean(ligand['x'], ligand['mask'], dim=0)

    ligand['x'] = ligand['x'] - com[ligand['mask']]
    return ligand, pocket
|
|
|
|
def get_bb_transform(n_xyz, ca_xyz, c_xyz):
    """
    Compute translation and rotation of the canonical backbone frame (triangle N-Ca-C) from a position with
    Ca at the origin, N on the x-axis and C in the xy-plane to the global position of the backbone frame

    Args:
        n_xyz: (n, 3)
        ca_xyz: (n, 3)
        c_xyz: (n, 3)

    Returns:
        axis-angle representation of the rotation, shape (n, 3)
        translation vector of shape (n, 3)
    """

    def rotation_matrix(angle, axis):
        # Batched rotation matrices about a single coordinate axis ('x'|'y'|'z'),
        # built from axis-angle vectors via the project's so3 utilities.
        axis_mapping = {'x': 0, 'y': 1, 'z': 2}
        axis = axis_mapping[axis]
        vector = torch.zeros(len(angle), 3)
        vector[:, axis] = 1
        return so3.matrix_from_rotation_vector(angle.view(-1, 1) * vector)

    # translation: the global CA position; shift so CA sits at the origin
    translation = ca_xyz
    n_xyz = n_xyz - translation
    c_xyz = c_xyz - translation

    # rotation about y chosen from N's x/z components
    # NOTE(review): intended to null N's z component — verify against the
    # so3 sign conventions before relying on this geometrically
    theta_y = torch.arctan2(n_xyz[:, 2], -n_xyz[:, 0])
    Ry = rotation_matrix(theta_y, 'y')
    Ry = Ry.transpose(2, 1)  # transpose = inverse for rotation matrices
    n_xyz = torch.einsum('noi,ni->no', Ry, n_xyz)

    # rotation about z chosen from N's (rotated) x/y components
    theta_z = torch.arctan2(n_xyz[:, 1], n_xyz[:, 0])
    Rz = rotation_matrix(theta_z, 'z')
    Rz = Rz.transpose(2, 1)

    # apply Rz·Ry to C, then pick the rotation about x from C's y/z components
    c_xyz = torch.einsum('noj,nji,ni->no', Rz, Ry, c_xyz)
    theta_x = torch.arctan2(c_xyz[:, 2], c_xyz[:, 1])
    Rx = rotation_matrix(theta_x, 'x')
    Rx = Rx.transpose(2, 1)

    # invert each factor (transpose) and compose in reverse order to obtain
    # the canonical-frame -> global-frame rotation R = Ry^T · Rz^T · Rx^T
    Ry = Ry.transpose(2, 1)
    Rz = Rz.transpose(2, 1)
    Rx = Rx.transpose(2, 1)
    R = torch.einsum('nok,nkj,nji->noi', Ry, Rz, Rx)

    return so3.rotation_vector_from_matrix(R), translation
|
|
|
|
class Residues(TensorDict):
    """
    Dictionary-like container for residues that supports some basic transformations.
    """

    # all keys a Residues object may contain
    KEYS = {'x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec', 'fixed_coord',
            'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral',
            'chi_indices', 'axis_angle', 'mask', 'bond_mask'}

    # keys holding absolute positions (translated by center())
    COORD_KEYS = {'x', 'fixed_coord'}

    # keys holding direction-like features (rotated by set_frame, never translated)
    VECTOR_KEYS = {'v', 'nma_vec'}

    # properties changed by both side-chain (set_chi) and backbone (set_frame) updates
    MUTABLE_PROPS_SS_AND_BB = {'v'}

    # properties changed only by side-chain (torsion) updates
    MUTABLE_PROPS_SS = {'chi'}

    # properties changed only by backbone (frame) updates
    MUTABLE_PROPS_BB = {'x', 'fixed_coord', 'axis_angle', 'nma_vec'}

    # properties that no transformation modifies
    IMMUTABLE_PROPS = {'mask', 'one_hot', 'bonds', 'bond_one_hot', 'bond_mask',
                       'atom_mask', 'nerf_indices', 'length', 'theta',
                       'ddihedral', 'chi_indices', 'name', 'size', 'n_bonds'}

    def copy(self):
        """Shallow copy: tensors are shared with the original."""
        data = super().copy()
        return Residues(**data)

    def deepcopy(self):
        """Deep copy: tensors are cloned, other values copied with deepcopy."""
        data = {k: v.clone() if torch.is_tensor(v) else deepcopy(v)
                for k, v in self.items()}
        return Residues(**data)

    def center(self):
        """Subtract the per-sample center of mass of 'x' from all coordinate
        keys, in place; returns the removed COM so callers can shift
        associated data (e.g. the ligand) by the same amount."""
        com = scatter_mean(self['x'], self['mask'], dim=0)
        self['x'] = self['x'] - com[self['mask']]
        # fixed_coord has an extra atom dimension, hence the unsqueeze
        self['fixed_coord'] = self['fixed_coord'] - com[self['mask']].unsqueeze(1)
        return com

    def set_empty_v(self):
        """Reset the side-chain vector features to an empty tensor."""
        self['v'] = torch.tensor([], device=self['x'].device)

    @torch.no_grad()
    def set_chi(self, chi_angles):
        """Overwrite the first five chi angles and rebuild the side-chain
        vectors 'v' from internal (NeRF) coordinates."""
        self['chi'][:, :5] = chi_angles
        nerf_params = {k: self[k] for k in ['fixed_coord', 'atom_mask',
                                            'nerf_indices', 'length', 'theta',
                                            'chi', 'ddihedral', 'chi_indices']}
        # 'v' stores atom positions relative to the CA position 'x'
        self['v'] = ic_to_coords(**nerf_params) - self['x'].unsqueeze(1)

    @torch.no_grad()
    def set_frame(self, new_ca_coord, new_axis_angle):
        """Move each residue to a new backbone frame (CA position plus
        axis-angle orientation), rotating backbone atoms and vector
        features by the relative rotation between old and new frames."""
        # backbone atoms expressed relative to the current CA
        bb_coord = self['fixed_coord']
        bb_coord = bb_coord - self['x'].unsqueeze(1)
        # relative rotation taking the old frame to the new one
        rotmat_before = so3.matrix_from_rotation_vector(self['axis_angle'])
        rotmat_after = so3.matrix_from_rotation_vector(new_axis_angle)
        rotmat_diff = rotmat_after @ rotmat_before.transpose(-1, -2)
        bb_coord = torch.einsum('boi,bai->bao', rotmat_diff, bb_coord)
        bb_coord = bb_coord + new_ca_coord.unsqueeze(1)

        self['x'] = new_ca_coord
        self['axis_angle'] = new_axis_angle
        self['fixed_coord'] = bb_coord
        # vector features are directions: rotate only, no translation
        self['v'] = torch.einsum('boi,bai->bao', rotmat_diff, self['v'])

    @staticmethod
    def empty(device):
        """Minimal placeholder Residues object with a single zero node."""
        return Residues(
            x=torch.zeros(1, 3, device=device).float(),
            mask=torch.zeros(1, 1, device=device).long(),
            size=torch.zeros(1, device=device).long(),
        )
|
|
|
|
def randomize_tensors(tensor_dict, exclude_keys=None):
    """Replace tensors with random tensors with the same shape.

    Floating tensors get standard-normal noise, integer tensors get uniform
    values in [-42, 42); excluded keys and non-tensors are left untouched.
    Modifies and returns the input dict.
    """
    excluded = set(exclude_keys) if exclude_keys is not None else set()
    for key, value in tensor_dict.items():
        if not isinstance(value, torch.Tensor) or key in excluded:
            continue
        if torch.is_floating_point(value):
            tensor_dict[key] = torch.randn_like(value)
        else:
            tensor_dict[key] = torch.randint_like(value, low=-42, high=42)
    return tensor_dict
|
|