Skip to content

Commit

Permalink
update after applying models as a package
Browse files Browse the repository at this point in the history
  • Loading branch information
FridrichMethod committed Nov 12, 2024
1 parent 39618c4 commit 5818a06
Show file tree
Hide file tree
Showing 11 changed files with 33 additions and 24 deletions.
4 changes: 3 additions & 1 deletion examples/ligandmpnn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"checkpoint = torch.load(\n",
" \"../models/model_params/ligandmpnn_v_32_020_25.pt\", map_location=device\n",
" \"../models/model_params/ligandmpnn_v_32_020_25.pt\",\n",
" map_location=device,\n",
" weights_only=True,\n",
")\n",
"ligand_mpnn = LigandMPNNBatch(\n",
" model_type=\"ligand_mpnn\",\n",
Expand Down
9 changes: 6 additions & 3 deletions examples/sample_complex.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/bin/bash

python ../run.py \
PYTHONPATH=$(pwd)/.. python -m models.run \
--model_type "ligand_mpnn" \
--checkpoint_ligand_mpnn "../model_params/ligandmpnn_v_32_020_25.pt" \
--temperature 0.1 \
--temperature 0.3 \
--pdb_path "$1" \
--out_folder "$2" \
--redesigned_residues "A1 A3 A4 A5 A7 A8 A9 A13 A14 A15 A19 A20 A21 A23 A24 A25 A26 A27 A39 A41 A44 A45 A46 A48 A50 A52 A53 A67 A68 A69 A72 A73 A74 A75 A76 A77 A78 A79 A80 A81 A82 A83 A84 A85 A86 A88 A89 A91 A92 A93 A95 A97 A99 A100 A102 A114 A116 A118 A119 A120 A121 A123 A124" \
Expand All @@ -14,4 +14,7 @@ python ../run.py \
--fasta_seq_separation ":" \
--batch_size 256 \
--number_of_batches 256 \
--verbose 1 \
--verbose 1

# --checkpoint_ligand_mpnn "../model_params/ligandmpnn_v_32_030_25.pt" \
# --temperature 0.2 \
2 changes: 1 addition & 1 deletion examples/score_complex.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

python ../score.py \
PYTHONPATH=$(pwd)/.. python -m models.score \
--model_type "ligand_mpnn" \
--checkpoint_ligand_mpnn "../model_params/ligandmpnn_v_32_020_25.pt" \
--pdb_path "$1" \
Expand Down
6 changes: 4 additions & 2 deletions examples/score_wt.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

python ../score.py \
PYTHONPATH=$(pwd)/.. python -m models.score \
--model_type "ligand_mpnn" \
--checkpoint_ligand_mpnn "../model_params/ligandmpnn_v_32_010_25.pt" \
--pdb_path "$1" \
Expand All @@ -12,4 +12,6 @@ python ../score.py \
--ligand_mpnn_use_side_chain_context 1 \
--batch_size 16 \
--number_of_batches 1 \
--verbose 1 \
--verbose 1

# --checkpoint_ligand_mpnn "../model_params/ligandmpnn_v_32_020_25.pt" \
6 changes: 3 additions & 3 deletions models/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from prody import writePDB

from .data_utils import (
from models.data_utils import (
alphabet,
element_dict_rev,
featurize,
Expand All @@ -21,7 +21,7 @@
restype_str_to_int,
write_full_PDB,
)
from .model_utils import ProteinMPNN
from models.model_utils import ProteinMPNN

# from sc_utils import Packer, pack_side_chains

Expand Down Expand Up @@ -66,7 +66,7 @@ def main(args) -> None:
else:
print("Choose one of the available models")
sys.exit()
checkpoint = torch.load(checkpoint_path, map_location=device)
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
if args.model_type == "ligand_mpnn":
atom_context_num = checkpoint["atom_context_num"]
ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
Expand Down
2 changes: 1 addition & 1 deletion models/sample_esmif.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from esm.inverse_folding.util import CoordBatchConverter, load_structure
from torch_geometric.nn import MessagePassing

from .globals import AA_ALPHABET
from models.globals import AA_ALPHABET


def sample_complex(
Expand Down
6 changes: 3 additions & 3 deletions models/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import numpy as np
import torch

from .data_utils import (
from models.data_utils import (
alphabet,
element_dict_rev,
featurize,
parse_PDB,
restype_int_to_str,
)
from .model_utils import ProteinMPNN
from models.model_utils import ProteinMPNN


def main(args) -> None:
Expand Down Expand Up @@ -48,7 +48,7 @@ def main(args) -> None:
else:
print("Choose one of the available models")
sys.exit()
checkpoint = torch.load(checkpoint_path, map_location=device)
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
if args.model_type == "ligand_mpnn":
atom_context_num = checkpoint["atom_context_num"]
ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
Expand Down
2 changes: 1 addition & 1 deletion models/score_esm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

from .globals import AA_ALPHABET, AA_DICT, CHAIN_ALPHABET
from models.globals import AA_ALPHABET, AA_DICT, CHAIN_ALPHABET


def score_complex(
Expand Down
2 changes: 1 addition & 1 deletion models/score_esmif.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from esm.inverse_folding.util import CoordBatchConverter, load_structure
from torch_geometric.nn import MessagePassing

from .globals import AA_ALPHABET, AA_DICT, CHAIN_ALPHABET
from models.globals import AA_ALPHABET, AA_DICT, CHAIN_ALPHABET


def _concatenate_seqs(
Expand Down
16 changes: 9 additions & 7 deletions models/score_ligandmpnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpy as np
import torch

from .data_utils import element_dict_rev, featurize, parse_PDB
from .globals import AA_DICT
from .model_utils import ProteinMPNN, cat_neighbors_nodes
from models.data_utils import element_dict_rev, featurize, parse_PDB
from models.globals import AA_DICT
from models.model_utils import ProteinMPNN, cat_neighbors_nodes


class LigandMPNNBatch(ProteinMPNN):
Expand Down Expand Up @@ -205,12 +205,12 @@ def score_complex(
perplexity = torch.exp(loss.mean(dim=-1)) # (B,)

if verbose:
if "Y" in protein_dict:
atom_masks = protein_dict.get("Y_m", torch.tensor([])).cpu().numpy()
if lig_atom_num := np.sum(atom_masks):
atom_coords = protein_dict["Y"].cpu().numpy()
atom_types = protein_dict["Y_t"].cpu().numpy()
atom_masks = protein_dict["Y_m"].cpu().numpy()
print(
f"The number of ligand atoms parsed is equal to: {np.sum(atom_masks)}"
f"The number of ligand atoms parsed is equal to: {lig_atom_num}"
)
for atom_type, atom_coord, atom_mask in zip(
atom_types, atom_coords, atom_masks
Expand All @@ -228,7 +228,9 @@ def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint = torch.load(
"../model_params/ligandmpnn_v_32_020_25.pt", map_location=device
"../model_params/ligandmpnn_v_32_020_25.pt",
map_location=device,
weights_only=True,
)
ligand_mpnn = LigandMPNNBatch(
model_type="ligand_mpnn",
Expand Down
2 changes: 1 addition & 1 deletion models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd
from Bio.Align import substitution_matrices

from .globals import AA_ALPHABET, AA_DICT
from models.globals import AA_ALPHABET, AA_DICT


def _normalize_submat(submat: np.ndarray) -> np.ndarray:
Expand Down

0 comments on commit 5818a06

Please sign in to comment.