-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8c2913f
commit 39618c4
Showing
19 changed files
with
670 additions
and
203 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import sys\n", | ||
"\n", | ||
"sys.path.append(\"..\")\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import torch\n", | ||
"from transformers import AutoTokenizer, EsmForMaskedLM\n", | ||
"\n", | ||
"from models.score_esm2 import score_complex\n", | ||
"\n", | ||
"# esm_3B_model = EsmForMaskedLM.from_pretrained(\"facebook/esm2_t36_3B_UR50D\")\n", | ||
"# esm_3B_model = esm_3B_model.eval().cuda()\n", | ||
"# esm_3B_tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t36_3B_UR50D\")\n", | ||
"esm_650M_model = EsmForMaskedLM.from_pretrained(\"facebook/esm2_t33_650M_UR50D\")\n", | ||
"esm_650M_model = esm_650M_model.eval().cuda()\n", | ||
"esm_650M_tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t33_650M_UR50D\")\n", | ||
"\n", | ||
"AA_ALPHABET = \"ACDEFGHIKLMNPQRSTVWY\"\n", | ||
"AA_DICT = {aa: i for i, aa in enumerate(AA_ALPHABET)}\n", | ||
"CHAIN_ALPHABET = \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# seqs = \"PSRLEEELRRRLTEP\"\n", | ||
"# seqs = \"SGEVQLQESGGGLVQPGGSLRLSCTASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS\"\n", | ||
"seqs = \"SGEVQLQESGGGLVQPGGSLRLSATASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS:PSRLEEELRRRLTEP\"\n", | ||
"\n", | ||
"# entropy, loss, perplexity = score_complex(esm_3B_model, esm_3B_tokenizer, seqs, verbose=True)\n", | ||
"entropy, loss, perplexity = score_complex(\n", | ||
" esm_650M_model, esm_650M_tokenizer, seqs, verbose=True\n", | ||
")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "prodes", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import sys\n", | ||
"\n", | ||
"sys.path.append(\"..\")\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"import torch\n", | ||
"import torch_geometric\n", | ||
"import torch_sparse\n", | ||
"from torch_geometric.nn import MessagePassing\n", | ||
"from tqdm.auto import tqdm\n", | ||
"\n", | ||
"import esm\n", | ||
"from models.score_esmif import score_complex\n", | ||
"from models.sample_esmif import sample_complex\n", | ||
"\n", | ||
"esmif_model, esmif_alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()\n", | ||
"esmif_model = esmif_model.eval().cuda()\n", | ||
"\n", | ||
"AA_ALPHABET = \"ACDEFGHIKLMNPQRSTVWY\"\n", | ||
"AA_DICT = {aa: i for i, aa in enumerate(AA_ALPHABET)}\n", | ||
"CHAIN_ALPHABET = \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pdbfile = \"../pdbs/NbALFA_ALFAtag_AF3.pdb\"\n", | ||
"entropy, loss, perplexity = score_complex(\n", | ||
" esmif_model,\n", | ||
" esmif_alphabet,\n", | ||
" pdbfile,\n", | ||
" target_seq_list=[\n", | ||
" \"GQVQLQQSAELARPGASVKMSCKASGYTFTSQAPGKGLEWVSAITWNELARPGASVKMSGHIDYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKYPYYGSHWYFDVWGAGTTVTVS\",\n", | ||
" \"KGQVQLQQSAELALARMSCKASYTFTSQAPGKGLEWVSAITWNELARPGASVKMSGHIDYADSVKGRFSGHIDYADSVKGRFTIPGASVKMSGTEKMSCTAVYYCAKYPGQVQLQQSAELAASS\",\n", | ||
" \"ARPGASVNELARPGASVKMSGHIDYAKMSCKASGYTFTSQAPGLEWVSAITWNELKASGYFTSQAPLQMLYLAVYYCAKPYYGSHVWGAVSAITWGVQLYAVAKYSRDNSKNTTVTVGTTVTVS\",\n", | ||
" \"PGLRAEDTAVYYCAKYPYELARPGYTFTSQAPGKGLGSHWYFDVWWYFDLYQMNSLRATIRDNSKNTWVSEVWGAGTASKMSCKASGGSVKMEDTAVYYCAKYPYYGSHGAGTDNSKNTVVTVS\",\n", | ||
" \"ASVRPGLYLQMNSGQVQLQQSALQQSAELYYGSHWYFDVWGAGTTVHIDYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTDVWGAGTTVTEWVNSLRAEDARPGASVKMSGTDSVKGRFTISS\",\n", | ||
" \"SGEVQLQESGGGLVQPGGSLRLSCTASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS\",\n", | ||
" ],\n", | ||
")\n", | ||
"entropy.shape, loss, perplexity" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pdbfile = \"../pdbs/NbALFA_AF3.pdb\"\n", | ||
"output_path = \"../results/sampled_monomer_esm.fasta\"\n", | ||
"redisigned_residues = \"1 3 4 5 7 8 9 13 14 15 19 20 21 23 24 25 26 27 39 41 44 45 46 48 50 52 53 67 68 69 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 88 89 91 92 93 95 97 99 100 102 114 116 118 119 120 121 123 124\"\n", | ||
"\n", | ||
"for batch_num in tqdm(range(2048)):\n", | ||
" sample_complex(\n", | ||
" esmif_model,\n", | ||
" esmif_alphabet,\n", | ||
" pdbfile,\n", | ||
" output_path,\n", | ||
" target_chain_id=\"A\",\n", | ||
" batch_size=32,\n", | ||
" redesigned_residues=redisigned_residues,\n", | ||
" omit_aa=\"C\",\n", | ||
" temperature=1.0,\n", | ||
" padding_length=10,\n", | ||
" index_offset=batch_num * 32,\n", | ||
" )" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "prodes", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import sys\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import torch\n", | ||
"\n", | ||
"sys.path.append(\"../\")\n", | ||
"\n", | ||
"from models.score_ligandmpnn import LigandMPNNBatch, score_complex\n", | ||
"\n", | ||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | ||
"checkpoint = torch.load(\n", | ||
" \"../models/model_params/ligandmpnn_v_32_020_25.pt\", map_location=device\n", | ||
")\n", | ||
"ligand_mpnn = LigandMPNNBatch(\n", | ||
" model_type=\"ligand_mpnn\",\n", | ||
" k_neighbors=32,\n", | ||
" atom_context_num=25,\n", | ||
" ligand_mpnn_use_side_chain_context=True,\n", | ||
" device=device,\n", | ||
")\n", | ||
"ligand_mpnn.load_state_dict(checkpoint[\"model_state_dict\"])\n", | ||
"ligand_mpnn.to(device)\n", | ||
"ligand_mpnn.eval()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pdbfile = \"../pdbs/NbALFA_ALFAtag_AF3.pdb\"\n", | ||
"chains_to_design = \"A\"\n", | ||
"redesigned_residues = \"A1 A3 A4 A5 A7 A8 A9 A13 A14 A15 A19 A20 A21 A23 A24 A25 A26 A27 A39 A41 A44 A45 A46 A48 A50 A52 A53 A67 A68 A69 A72 A73 A74 A75 A76 A77 A78 A79 A80 A81 A82 A83 A84 A85 A86 A88 A89 A91 A92 A93 A95 A97 A99 A100 A102 A114 A116 A118 A119 A120 A121 A123 A124\"\n", | ||
"target_seqs_list = [\n", | ||
" \"SGEVQLQESGGGLVQPGGSLRLSCTASGVTISALNAMAMGWYRQAPGERRVMVAAVSERGNAMYRESVQGRFTVTRDFTNKMVSLQMDNLKPEDTAVYYCHVLEDRVDSFHDYWGQGTQVTVSS:PSRLEEELRRRLTEP\",\n", | ||
" \"GGTVVLTESGGGTVAPGGSATLTATASGVTISALNAMAWGWYRQRPGERPVAVAAVSERGNAMYREDVRGRWTVTADRANKTVSLEMRDLQPEDTATYYPHVLEDRVDSFHDYWGAGVPLTVVP:PSRLEEELRRRLTEP\",\n", | ||
" \"GQVQLQQSAELARPGASVKMSCKASGYTFTSQAPGKGLEWVSAITWNELARPGASVKMSGHIDYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKYPYYGSHWYFDVWGAGTTVTVS:PSRLEEELRRRLTEP\",\n", | ||
" \"PGLRAEDTAVYYCAKYPYELARPGYTFTSQAPGKGLGSHWYFDVWWYFDLYQMNSLRATIRDNSKNTWVSEVWGAGTASKMSCKASGGSVKMEDTAVYYCAKYPYYGSHGAGTDNSKNAVVTVS:PSRLEEELRRRLTEP\",\n", | ||
"]\n", | ||
"\n", | ||
"entropy, loss, perplexity = score_complex(\n", | ||
" ligand_mpnn,\n", | ||
" pdbfile,\n", | ||
" # chains_to_design=chains_to_design,\n", | ||
" redesigned_residues=redesigned_residues,\n", | ||
" seqs_list=target_seqs_list,\n", | ||
" use_side_chain_context=True,\n", | ||
")\n", | ||
"entropy.shape, loss, perplexity" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def extract_from_score(output_path):\n", | ||
" with open(output_path, \"rb\") as f:\n", | ||
" output = torch.load(f)\n", | ||
"\n", | ||
" entropy = -(\n", | ||
" torch.tensor(output[\"logits\"][:, :, :20]).softmax(dim=-1).mean(dim=0).log()\n", | ||
" ) # (L, 20)\n", | ||
" target = torch.tensor(output[\"native_sequence\"], dtype=torch.long) # (L,)\n", | ||
" loss = torch.gather(entropy, 1, target.unsqueeze(1)).squeeze() # (L,)\n", | ||
" perplexity = torch.exp(loss.mean()).item() # scalar\n", | ||
"\n", | ||
" return entropy, loss, perplexity\n", | ||
"\n", | ||
"\n", | ||
"def extract_from_sample(output_path):\n", | ||
" with open(output_path, \"rb\") as f:\n", | ||
" output = torch.load(f)\n", | ||
"\n", | ||
" entropy = -output[\"log_probs\"] # (B, L, 20)\n", | ||
" target = output[\"generated_sequences\"] # (B, L)\n", | ||
" loss = torch.gather(entropy, 2, target.unsqueeze(2)).squeeze(2) # (B, L)\n", | ||
" perplexity = torch.exp(loss.mean(dim=-1)) # (B,)\n", | ||
" # redesigned = output[\"chain_mask\"] == 1\n", | ||
" # confidence = torch.exp(-loss[:, redesigned].mean(dim=-1))\n", | ||
"\n", | ||
" return entropy, loss, perplexity" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# !sh \"./score_complex.sh\" \"../pdbs/NbALFA_ALFAtag_AF3.pdb\" \"../results_/score/\"\n", | ||
"# extract_from_score(\"../results_/score/NbALFA_ALFAtag_AF3.pt\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# !sh \"./score_wt.sh\" \"../pdbs/NbALFA_ALFAtag_AF3.pdb\" \"../results/score\"\n", | ||
"# extract_from_score(\"../results/score/NbALFA_ALFAtag_AF3.pt\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# !sh \"./sample_complex.sh\" \"../pdbs/NbALFA_AF3.pdb\" \"../results/sample\"\n", | ||
"# extract_from_sample(\"../results/sample/stats/NbALFA_AF3.pt\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "prodes", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#!/bin/bash | ||
|
||
python ../run.py \ | ||
--model_type "ligand_mpnn" \ | ||
--checkpoint_ligand_mpnn "../model_params/ligandmpnn_v_32_020_25.pt" \ | ||
--temperature 0.1 \ | ||
--pdb_path "$1" \ | ||
--out_folder "$2" \ | ||
--redesigned_residues "A1 A3 A4 A5 A7 A8 A9 A13 A14 A15 A19 A20 A21 A23 A24 A25 A26 A27 A39 A41 A44 A45 A46 A48 A50 A52 A53 A67 A68 A69 A72 A73 A74 A75 A76 A77 A78 A79 A80 A81 A82 A83 A84 A85 A86 A88 A89 A91 A92 A93 A95 A97 A99 A100 A102 A114 A116 A118 A119 A120 A121 A123 A124" \ | ||
--omit_AA "C" \ | ||
--ligand_mpnn_use_side_chain_context 1 \ | ||
--save_stats 1 \ | ||
--zero_indexed 1 \ | ||
--fasta_seq_separation ":" \ | ||
--batch_size 256 \ | ||
--number_of_batches 256 \ | ||
--verbose 1 \ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#!/bin/bash | ||
|
||
python ../score.py \ | ||
--model_type "ligand_mpnn" \ | ||
--checkpoint_ligand_mpnn "../model_params/ligandmpnn_v_32_020_25.pt" \ | ||
--pdb_path "$1" \ | ||
--out_folder "$2" \ | ||
--redesigned_residues "A1 A3 A4 A5 A7 A8 A9 A13 A14 A15 A19 A20 A21 A23 A24 A25 A26 A27 A39 A41 A44 A45 A46 A48 A50 A52 A53 A67 A68 A69 A72 A73 A74 A75 A76 A77 A78 A79 A80 A81 A82 A83 A84 A85 A86 A88 A89 A91 A92 A93 A95 A97 A99 A100 A102 A114 A116 A118 A119 A120 A121 A123 A124" \ | ||
--use_sequence 1 \ | ||
--autoregressive_score 0 \ | ||
--single_aa_score 1 \ | ||
--ligand_mpnn_use_side_chain_context 1 \ | ||
--batch_size 1 \ | ||
--number_of_batches 1 \ | ||
--verbose 1 | ||
|
||
# --chains_to_design "A" \ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#!/bin/bash | ||
|
||
python ../score.py \ | ||
--model_type "ligand_mpnn" \ | ||
--checkpoint_ligand_mpnn "../model_params/ligandmpnn_v_32_010_25.pt" \ | ||
--pdb_path "$1" \ | ||
--out_folder "$2" \ | ||
--chains_to_design "A" \ | ||
--use_sequence 1 \ | ||
--autoregressive_score 0 \ | ||
--single_aa_score 1 \ | ||
--ligand_mpnn_use_side_chain_context 1 \ | ||
--batch_size 16 \ | ||
--number_of_batches 1 \ | ||
--verbose 1 \ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import torch | ||
|
||
torch.backends.cuda.matmul.allow_tf32 = True |
Oops, something went wrong.