Skip to content

Commit

Permalink
add esm2 batch scoring function; revise docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
FridrichMethod committed Nov 13, 2024
1 parent 5818a06 commit 6a0f6cc
Show file tree
Hide file tree
Showing 8 changed files with 345 additions and 190 deletions.
30 changes: 18 additions & 12 deletions examples/esm2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@
"\n",
"from models.score_esm2 import score_complex\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# esm_3B_model = EsmForMaskedLM.from_pretrained(\"facebook/esm2_t36_3B_UR50D\")\n",
"# esm_3B_model = esm_3B_model.eval().cuda()\n",
"# esm_3B_tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t36_3B_UR50D\")\n",
"# esm_3B_model = esm_3B_model.eval().to(device)\n",
"esm_650M_model = EsmForMaskedLM.from_pretrained(\"facebook/esm2_t33_650M_UR50D\")\n",
"esm_650M_model = esm_650M_model.eval().cuda()\n",
"esm_650M_tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t33_650M_UR50D\")\n",
"\n",
"AA_ALPHABET = \"ACDEFGHIKLMNPQRSTVWY\"\n",
"AA_DICT = {aa: i for i, aa in enumerate(AA_ALPHABET)}\n",
"CHAIN_ALPHABET = \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\""
"esm_650M_model = esm_650M_model.eval().to(device)"
]
},
{
Expand All @@ -35,20 +33,28 @@
"metadata": {},
"outputs": [],
"source": [
"# seqs = \"PSRLEEELRRRLTEP\"\n",
"# seqs = \"SGEVQLQESGGGLVQPGGSLRLSCTASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS\"\n",
"seqs = \"SGEVQLQESGGGLVQPGGSLRLSATASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS:PSRLEEELRRRLTEP\"\n",
"# seqs_list = [\n",
"# \"GQVQLQQSAELARPGASVKMSCKASGYTFTSQAPGKGLEWVSAITWNELARPGASVKMSGHIDYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKYPYYGSHWYFDVWGAGTTVTVS:PSRLEEELRRRLTEP\",\n",
"# \"ARPGASVNELARPGASVKMSGHIDYAKMSCKASGYTFTSQAPGLEWVSAITWNELKASGYFTSQAPLQMLYLAVYYCAKPYYGSHVWGAVSAITWGVQLYAVAKYSRDNSKNTTVTVGTTVTVS:PSRLEEELRRRLTEP\",\n",
"# \"SGEVQLQESGGGLVQPGGSLRLSCTASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS:PSRLEEELRRRLTEP\",\n",
"# ]\n",
"seqs_list = [\n",
" \"PGLRAEDTAVYYCAKYPYELARPGYTFTSQAPGKGLGSHWYFDVWWYFDLYQMNSLRATIRDNSKNTWVSEVWGAGTASKMSCKASGGSVKMEDTAVYYCAKYPYYGSHGAGTDNSKNTVVTVS\",\n",
" \"SGEVQLQESGGGLVQPGGSLRLSCTASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS\",\n",
" \"SGEVQLQESGGGLVQPGGSLRLSATASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYAHVLEDRVDSFHDYWGQGTQVTVSS\",\n",
" \"TGMVTLDETGGGAVAPGGSLTLGMRATGVTISALNAMALGWYRQQPGLRAVLVAAVSERGNAMYREDVLGRYRVTRDAATRQVSLVMLQLLPSDTATYYGHVLEDRVDSFHDYWGEGTQVQVVP\",\n",
"]\n",
"\n",
"# entropy, loss, perplexity = score_complex(esm_3B_model, esm_3B_tokenizer, seqs, verbose=True)\n",
"# entropy, loss, perplexity = score_complex(esm_3B_model, esm_3B_tokenizer, seqs_list, verbose=True)\n",
"entropy, loss, perplexity = score_complex(\n",
" esm_650M_model, esm_650M_tokenizer, seqs, verbose=True\n",
" esm_650M_model, esm_650M_tokenizer, seqs_list, verbose=True\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "prodes",
"display_name": "pdmodels",
"language": "python",
"name": "python3"
},
Expand Down
65 changes: 32 additions & 33 deletions examples/esmif.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"\n",
"sys.path.append(\"..\")\n",
"\n",
"import esm\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
Expand All @@ -19,16 +20,13 @@
"from torch_geometric.nn import MessagePassing\n",
"from tqdm.auto import tqdm\n",
"\n",
"import esm\n",
"from models.score_esmif import score_complex\n",
"from models.sample_esmif import sample_complex\n",
"from models.score_esmif import score_complex\n",
"\n",
"esmif_model, esmif_alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()\n",
"esmif_model = esmif_model.eval().cuda()\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"AA_ALPHABET = \"ACDEFGHIKLMNPQRSTVWY\"\n",
"AA_DICT = {aa: i for i, aa in enumerate(AA_ALPHABET)}\n",
"CHAIN_ALPHABET = \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\""
"esmif_model, esmif_alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()\n",
"esmif_model = esmif_model.eval().to(device)"
]
},
{
Expand All @@ -38,18 +36,19 @@
"outputs": [],
"source": [
"pdbfile = \"../pdbs/NbALFA_ALFAtag_AF3.pdb\"\n",
"target_seq_list = [\n",
" \"GQVQLQQSAELARPGASVKMSCKASGYTFTSQAPGKGLEWVSAITWNELARPGASVKMSGHIDYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKYPYYGSHWYFDVWGAGTTVTVS\",\n",
" \"KGQVQLQQSAELALARMSCKASYTFTSQAPGKGLEWVSAITWNELARPGASVKMSGHIDYADSVKGRFSGHIDYADSVKGRFTIPGASVKMSGTEKMSCTAVYYCAKYPGQVQLQQSAELAASS\",\n",
" \"ARPGASVNELARPGASVKMSGHIDYAKMSCKASGYTFTSQAPGLEWVSAITWNELKASGYFTSQAPLQMLYLAVYYCAKPYYGSHVWGAVSAITWGVQLYAVAKYSRDNSKNTTVTVGTTVTVS\",\n",
" \"PGLRAEDTAVYYCAKYPYELARPGYTFTSQAPGKGLGSHWYFDVWWYFDLYQMNSLRATIRDNSKNTWVSEVWGAGTASKMSCKASGGSVKMEDTAVYYCAKYPYYGSHGAGTDNSKNTVVTVS\",\n",
" \"ASVRPGLYLQMNSGQVQLQQSALQQSAELYYGSHWYFDVWGAGTTVHIDYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTDVWGAGTTVTEWVNSLRAEDARPGASVKMSGTDSVKGRFTISS\",\n",
" \"SGEVQLQESGGGLVQPGGSLRLSCTASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS\",\n",
"]\n",
"entropy, loss, perplexity = score_complex(\n",
" esmif_model,\n",
" esmif_alphabet,\n",
" pdbfile,\n",
" target_seq_list=[\n",
" \"GQVQLQQSAELARPGASVKMSCKASGYTFTSQAPGKGLEWVSAITWNELARPGASVKMSGHIDYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKYPYYGSHWYFDVWGAGTTVTVS\",\n",
" \"KGQVQLQQSAELALARMSCKASYTFTSQAPGKGLEWVSAITWNELARPGASVKMSGHIDYADSVKGRFSGHIDYADSVKGRFTIPGASVKMSGTEKMSCTAVYYCAKYPGQVQLQQSAELAASS\",\n",
" \"ARPGASVNELARPGASVKMSGHIDYAKMSCKASGYTFTSQAPGLEWVSAITWNELKASGYFTSQAPLQMLYLAVYYCAKPYYGSHVWGAVSAITWGVQLYAVAKYSRDNSKNTTVTVGTTVTVS\",\n",
" \"PGLRAEDTAVYYCAKYPYELARPGYTFTSQAPGKGLGSHWYFDVWWYFDLYQMNSLRATIRDNSKNTWVSEVWGAGTASKMSCKASGGSVKMEDTAVYYCAKYPYYGSHGAGTDNSKNTVVTVS\",\n",
" \"ASVRPGLYLQMNSGQVQLQQSALQQSAELYYGSHWYFDVWGAGTTVHIDYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTDVWGAGTTVTEWVNSLRAEDARPGASVKMSGTDSVKGRFTISS\",\n",
" \"SGEVQLQESGGGLVQPGGSLRLSCTASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS\",\n",
" ],\n",
" target_seq_list=target_seq_list,\n",
")\n",
"entropy.shape, loss, perplexity"
]
Expand All @@ -60,30 +59,30 @@
"metadata": {},
"outputs": [],
"source": [
"pdbfile = \"../pdbs/NbALFA_AF3.pdb\"\n",
"output_path = \"../results/sampled_monomer_esm.fasta\"\n",
"redisigned_residues = \"1 3 4 5 7 8 9 13 14 15 19 20 21 23 24 25 26 27 39 41 44 45 46 48 50 52 53 67 68 69 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 88 89 91 92 93 95 97 99 100 102 114 116 118 119 120 121 123 124\"\n",
"# pdbfile = \"../pdbs/NbALFA_AF3.pdb\"\n",
"# output_path = \"../results/sampled_monomer_esm.fasta\"\n",
"# redisigned_residues = \"1 3 4 5 7 8 9 13 14 15 19 20 21 23 24 25 26 27 39 41 44 45 46 48 50 52 53 67 68 69 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 88 89 91 92 93 95 97 99 100 102 114 116 118 119 120 121 123 124\"\n",
"\n",
"for batch_num in tqdm(range(2048)):\n",
" sample_complex(\n",
" esmif_model,\n",
" esmif_alphabet,\n",
" pdbfile,\n",
" output_path,\n",
" target_chain_id=\"A\",\n",
" batch_size=32,\n",
" redesigned_residues=redisigned_residues,\n",
" omit_aa=\"C\",\n",
" temperature=1.0,\n",
" padding_length=10,\n",
" index_offset=batch_num * 32,\n",
" )"
"# for batch_num in tqdm(range(2048)):\n",
"# sample_complex(\n",
"# esmif_model,\n",
"# esmif_alphabet,\n",
"# pdbfile,\n",
"# output_path,\n",
"# target_chain_id=\"A\",\n",
"# batch_size=32,\n",
"# redesigned_residues=redisigned_residues,\n",
"# omit_aa=\"C\",\n",
"# temperature=1.0,\n",
"# padding_length=10,\n",
"# index_offset=batch_num * 32,\n",
"# )"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "prodes",
"display_name": "pdmodels",
"language": "python",
"name": "python3"
},
Expand Down
39 changes: 5 additions & 34 deletions examples/ligandmpnn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,11 @@
" weights_only=True,\n",
")\n",
"ligand_mpnn = LigandMPNNBatch(\n",
" model_type=\"ligand_mpnn\",\n",
" k_neighbors=32,\n",
" atom_context_num=25,\n",
" ligand_mpnn_use_side_chain_context=True,\n",
" device=device,\n",
")\n",
"ligand_mpnn.load_state_dict(checkpoint[\"model_state_dict\"])\n",
"ligand_mpnn.to(device)\n",
"ligand_mpnn.eval()"
"ligand_mpnn.eval().to(device)"
]
},
{
Expand All @@ -43,7 +39,7 @@
"pdbfile = \"../pdbs/NbALFA_ALFAtag_AF3.pdb\"\n",
"chains_to_design = \"A\"\n",
"redesigned_residues = \"A1 A3 A4 A5 A7 A8 A9 A13 A14 A15 A19 A20 A21 A23 A24 A25 A26 A27 A39 A41 A44 A45 A46 A48 A50 A52 A53 A67 A68 A69 A72 A73 A74 A75 A76 A77 A78 A79 A80 A81 A82 A83 A84 A85 A86 A88 A89 A91 A92 A93 A95 A97 A99 A100 A102 A114 A116 A118 A119 A120 A121 A123 A124\"\n",
"target_seqs_list = [\n",
"seqs_list = [\n",
" \"SGEVQLQESGGGLVQPGGSLRLSCTASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS:PSRLEEELRRRLTEP\",\n",
" \"GGTVVLTESGGGTVAPGGSATLTATASGVTISALNAMAWGWYRQRPGERPVAVAAVSERGNAMYREDVRGRWTVTADRANKTVSLEMRDLQPEDTATYYPHVLEDRVDSFHDYWGAGVPLTVVP:PSRLEEELRRRLTEP\",\n",
" \"GQVQLQQSAELARPGASVKMSCKASGYTFTSQAPGKGLEWVSAITWNELARPGASVKMSGHIDYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKYPYYGSHWYFDVWGAGTTVTVS:PSRLEEELRRRLTEP\",\n",
Expand All @@ -55,7 +51,7 @@
" pdbfile,\n",
" # chains_to_design=chains_to_design,\n",
" redesigned_residues=redesigned_residues,\n",
" seqs_list=target_seqs_list,\n",
" seqs_list=seqs_list,\n",
" use_side_chain_context=True,\n",
")\n",
"entropy.shape, loss, perplexity"
Expand All @@ -67,32 +63,7 @@
"metadata": {},
"outputs": [],
"source": [
"def extract_from_score(output_path):\n",
" with open(output_path, \"rb\") as f:\n",
" output = torch.load(f)\n",
"\n",
" entropy = -(\n",
" torch.tensor(output[\"logits\"][:, :, :20]).softmax(dim=-1).mean(dim=0).log()\n",
" ) # (L, 20)\n",
" target = torch.tensor(output[\"native_sequence\"], dtype=torch.long) # (L,)\n",
" loss = torch.gather(entropy, 1, target.unsqueeze(1)).squeeze() # (L,)\n",
" perplexity = torch.exp(loss.mean()).item() # scalar\n",
"\n",
" return entropy, loss, perplexity\n",
"\n",
"\n",
"def extract_from_sample(output_path):\n",
" with open(output_path, \"rb\") as f:\n",
" output = torch.load(f)\n",
"\n",
" entropy = -output[\"log_probs\"] # (B, L, 20)\n",
" target = output[\"generated_sequences\"] # (B, L)\n",
" loss = torch.gather(entropy, 2, target.unsqueeze(2)).squeeze(2) # (B, L)\n",
" perplexity = torch.exp(loss.mean(dim=-1)) # (B,)\n",
" # redesigned = output[\"chain_mask\"] == 1\n",
" # confidence = torch.exp(-loss[:, redesigned].mean(dim=-1))\n",
"\n",
" return entropy, loss, perplexity"
"from models.score_ligandmpnn import extract_from_sample, extract_from_score"
]
},
{
Expand Down Expand Up @@ -128,7 +99,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "prodes",
"display_name": "pdmodels",
"language": "python",
"name": "python3"
},
Expand Down
28 changes: 27 additions & 1 deletion models/sample_esmif.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,33 @@ def sample_complex(
temperature: float = 1.0,
padding_length: int = 10,
index_offset: int = 0,
):
) -> None:
"""Sample redesigned sequences for a given complex structure.
Args:
model (GVPTransformerModel):
Inverse folding model.
alphabet (Alphabet):
Alphabet object for encoding and decoding sequences.
pdbfile (str):
Path to the PDB file of the complex structure.
output_path (str):
Path to the output file.
target_chain_id (str):
Chain ID of the target sequence.
batch_size (int):
Number of sequences to sample.
redesigned_residues (str):
Residue positions to redesign, separated by spaces.
omit_aa (str):
Amino acids to omit from the output sequence.
temperature (float):
Temperature for sampling.
padding_length (int):
Length of padding between concatenated chains.
index_offset (int):
Offset for the sequence ID in the output file.
"""

device = next(model.parameters()).device

Expand Down
Loading

0 comments on commit 6a0f6cc

Please sign in to comment.