| import warnings |
|
|
| import torch |
| from rdkit import Chem |
| from rdkit.Chem import Draw, AllChem |
| from rdkit.Chem import SanitizeFlags |
| from src.analysis.metrics import check_mol |
| from src import utils |
| from src.data.molecule_builder import build_molecule |
| from src.data.misc import protein_letters_1to3 |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
def pocket_to_rdkit(pocket, pocket_representation, atom_encoder=None,
                    atom_decoder=None, aa_decoder=None, residue_decoder=None,
                    aa_atom_index=None):
    """Build one molecule per pocket found in a batched pocket dictionary.

    Args:
        pocket: Mapping of batched tensors. ``'mask'`` assigns each node to a
            sample; ``'x'`` and ``'one_hot'`` are indexed with that mask.
            The ``'CA+'`` representation additionally reads ``'v'`` and
            ``'atom_mask'``.
        pocket_representation: Either ``'side_chain_bead'`` or ``'CA+'``.
        atom_encoder: Maps an atom symbol to the integer type handed to
            ``build_molecule``.
        atom_decoder: Forwarded unchanged to ``build_molecule``.
        aa_decoder: Indexed with the argmax of ``'one_hot'`` (``'CA+'`` only);
            its values are used as keys of ``aa_atom_index`` and
            ``protein_letters_1to3``.
        residue_decoder: Indexed with the argmax over the last
            ``len(residue_decoder)`` one-hot channels
            (``'side_chain_bead'`` only).
        aa_atom_index: Maps an amino acid to a ``{atom_name: index}`` dict;
            the index addresses both the per-residue offset vectors and the
            per-residue atom mask (``'CA+'`` only).

    Returns:
        List with the object returned by ``build_molecule`` for each unique
        value in ``pocket['mask']``, in the order given by ``torch.unique``.

    Raises:
        NotImplementedError: If ``pocket_representation`` is not supported
            and the batch contains at least one sample.
    """
    rdpockets = []
    for i in torch.unique(pocket['mask']):
        # Compute the per-sample selection once and reuse it for every tensor.
        sample = pocket['mask'] == i
        node_coord = pocket['x'][sample]
        h = pocket['one_hot'][sample]

        pdb_infos = []

        if pocket_representation == 'side_chain_bead':
            coord = node_coord

            residue_channels = h[:, -len(residue_decoder):]
            node_types = [residue_decoder[b]
                          for b in residue_channels.argmax(-1)]
            # 'CA' nodes are encoded as carbon, every other bead as fluorine.
            atom_types = ['C' if r == 'CA' else 'F' for r in node_types]

        elif pocket_representation == 'CA+':
            aa_types = [aa_decoder[b] for b in h.argmax(-1)]
            side_chain_vec = pocket['v'][sample]
            # Only this representation needs the atom mask, so it is looked
            # up here rather than unconditionally for every representation.
            atom_mask = pocket['atom_mask'][sample]

            coord = []
            atom_types = []
            residues = zip(node_coord, aa_types, side_chain_vec, atom_mask)
            for resi, (xyz, aa, vec, am) in enumerate(residues):
                for atom_name, idx in aa_atom_index[aa].items():

                    # `not` instead of `~`: bitwise NOT is only a logical
                    # negation for bool tensors; on integer masks `~1` is
                    # still truthy and every atom would be skipped.
                    if not am[idx]:
                        warnings.warn(f"Missing atom {atom_name} in {aa}:{resi}")
                        continue

                    # Atom position = node position + per-atom offset.
                    coord.append(xyz + vec[idx])
                    # The first character of the atom name is used as the
                    # atom type key for `atom_encoder`.
                    atom_types.append(atom_name[0])

                    info = Chem.AtomPDBResidueInfo()
                    info.SetResidueName(protein_letters_1to3[aa])
                    info.SetResidueNumber(resi + 1)
                    info.SetOccupancy(1.0)
                    info.SetTempFactor(0.0)
                    info.SetName(f' {atom_name:<3}')
                    pdb_infos.append(info)

            coord = torch.stack(coord, dim=0)

        else:
            raise NotImplementedError(f"{pocket_representation} residue representation not supported")

        atom_types = torch.tensor([atom_encoder[a] for a in atom_types])
        rdmol = build_molecule(coord, atom_types, atom_decoder=atom_decoder)

        # Residue annotations are attached only when there is exactly one
        # entry per atom (never the case for 'side_chain_bead' with atoms,
        # since no annotations are collected there).
        if len(pdb_infos) == len(rdmol.GetAtoms()):
            for a, info in zip(rdmol.GetAtoms(), pdb_infos):
                a.SetPDBResidueInfo(info)

        rdpockets.append(rdmol)

    return rdpockets
|
|
|
|
def mols_to_pdbfile(rdmols, filename, flavor=0):
    """Write molecules as consecutive MODEL entries of one PDB file.

    Args:
        rdmols: Iterable of molecules accepted by ``Chem.MolToPDBBlock``.
        filename: Path of the file to write (opened in text mode, truncated).
        flavor: Forwarded to ``Chem.MolToPDBBlock``.

    Returns:
        The full text that was written to ``filename``.
    """
    sections = []
    for model_no, mol in enumerate(rdmols, start=1):
        block_lines = Chem.MolToPDBBlock(mol, flavor=flavor).split("\n")
        # Drop the last two split items of each block (presumably RDKit's
        # closing END record plus the empty string after the final newline)
        # so that a single END can terminate the whole file.
        records = "\n".join(block_lines[:-2])
        sections.append(f"MODEL{model_no:>9}\n{records}\nENDMDL\n")
    sections.append("END\n")
    pdb_str = "".join(sections)

    with open(filename, 'w') as f:
        f.write(pdb_str)

    return pdb_str
|
|
|
|
def mol_as_pdb(rdmol, filename=None, bfactor=None):
    """Return the PDB block of a molecule, optionally saving it to a file.

    All changes are applied to ``Chem.Mol(rdmol)``, not to ``rdmol`` itself.

    Args:
        rdmol: Molecule to export.
        filename: If not ``None``, the PDB text is also written to this path.
        bfactor: If not ``None``, name of an atom property whose value is
            stored as the temperature factor of each atom. Every atom must
            carry this property, otherwise the lookup raises ``KeyError``.

    Returns:
        The PDB block as a string.
    """
    mol_copy = Chem.Mol(rdmol)

    # Clear all aromaticity flags on the copy before export.
    for atom in mol_copy.GetAtoms():
        atom.SetIsAromatic(False)
    for bond in mol_copy.GetBonds():
        bond.SetIsAromatic(False)

    if bfactor is not None:
        for atom in mol_copy.GetAtoms():
            temp_factor = atom.GetPropsAsDict()[bfactor]

            # Every atom is placed in residue 'UNL' number 1 as a hetero
            # atom, with the property value as its temperature factor.
            residue_info = Chem.AtomPDBResidueInfo()
            residue_info.SetResidueName('UNL')
            residue_info.SetResidueNumber(1)
            residue_info.SetName(f' {atom.GetSymbol():<3}')
            residue_info.SetIsHeteroAtom(True)
            residue_info.SetOccupancy(1.0)
            residue_info.SetTempFactor(temp_factor)
            atom.SetPDBResidueInfo(residue_info)

    pdb_str = Chem.MolToPDBBlock(mol_copy)

    if filename is None:
        return pdb_str

    with open(filename, 'w') as f:
        f.write(pdb_str)
    return pdb_str
|
|
|
|
def draw_grid(molecules, mols_per_row=5, fig_size=(200, 200),
              label=check_mol,
              highlight_atom=lambda atom: False,
              highlight_bond=lambda bond: False):
    """Draw molecules on a grid with per-molecule legends and highlights.

    Args:
        molecules: Iterable of molecules to draw; each one is copied before
            it is prepared for drawing.
        mols_per_row: Number of grid columns.
        fig_size: Size of each grid cell, passed as ``subImgSize``.
        label: Callable applied to each *prepared* molecule; its result is
            shown in the legend after the molecule's index.
        highlight_atom: Predicate selecting the atoms to highlight.
        highlight_bond: Predicate selecting the bonds to highlight.

    Returns:
        Whatever ``Draw.MolsToGridImage`` returns for the prepared molecules.
    """
    def _selected_indices(items, predicate):
        # Indices of the atoms/bonds for which the predicate holds.
        return [item.GetIdx() for item in items if predicate(item)]

    prepared = []
    atom_highlights = []
    bond_highlights = []
    for mol in molecules:
        canvas_mol = Chem.Mol(mol)
        # Sanitization is invoked with every operation switched off.
        Chem.SanitizeMol(canvas_mol, sanitizeOps=SanitizeFlags.SANITIZE_NONE)
        AllChem.Compute2DCoords(canvas_mol)
        canvas_mol = Draw.rdMolDraw2D.PrepareMolForDrawing(
            canvas_mol, kekulize=False)

        prepared.append(canvas_mol)
        atom_highlights.append(
            _selected_indices(canvas_mol.GetAtoms(), highlight_atom))
        bond_highlights.append(
            _selected_indices(canvas_mol.GetBonds(), highlight_bond))

    options = Draw.rdMolDraw2D.MolDrawOptions()
    # The molecules were already prepared in the loop above.
    options.prepareMolsBeforeDrawing = False
    options.highlightBondWidthMultiplier = 20

    legends = [f'[{i}] {label(mol)}' for i, mol in enumerate(prepared)]

    return Draw.MolsToGridImage(
        prepared,
        molsPerRow=mols_per_row,
        subImgSize=fig_size,
        drawOptions=options,
        highlightAtomLists=atom_highlights,
        highlightBondLists=bond_highlights,
        legends=legends,
    )
|
|