Skip to content

Commit 8aa0846

Browse files
committed
fix minor error
1 parent 89f1782 commit 8aa0846

File tree

13 files changed

+954
-1025
lines changed

13 files changed

+954
-1025
lines changed

README.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ The results of unconditional protein sequence generation of DPLM of different sc
111111
| 650M | 74.00 (+0.69) | 85.61 (+1.31) | 85.91 (+1.09) | 88.16 (+1.26) | 82.58 (+0.87) | 84.38 (+2.85) | 83.87 (+2.31) | 83.00 (+2.08) | 84.92 (+6.21) | 81.51 (+9.41) |
112112
| 3B | 77.78 (+4.47) | 86.16 (+1.86) | 87.39 (+2.57) | 90.06 (+3.16) | 87.43 (+5.72) | 86.01 (+4.48) | 84.64 (+3.08) | 85.88 (+4.96) | 85.93 (+7.22) | 83.86 (+11.76) |
113113

114-
To generate new protein sequences using a pre-trained DPLM model:
114+
To generate new protein sequences using a pre-trained DPLM model, and evaluate with ESMFold:
115115

116116
```bash
117117
model_name=dplm_650m # choose from dplm_150m, dplm_650m, dplm_3b
@@ -217,10 +217,10 @@ eval_sc=false
217217
# TMscore and pLDDT during generation,
218218
# thus significantly increase the evaluation time.
219219

220-
python test.py \
220+
python test.py \
221221
experiment_path=${exp_path} \
222222
data_split=test ckpt_path=best.ckpt mode=predict \
223-
task.generator.max_iter=100 task.generator.eval_sc=${eval_sc}
223+
+task.generator.max_iter=100 +task.generator.eval_sc=${eval_sc}
224224
```
225225

226226
## Representation Learning

analysis/motif_analysis.py

+238
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
import seaborn as sns
2+
import matplotlib.pyplot as plt
3+
import pandas as pd
4+
import os
5+
import MDAnalysis as mda
6+
from MDAnalysis.analysis import rms
7+
from ast import literal_eval
8+
import subprocess
9+
from Bio import PDB
10+
import numpy as np
11+
import argparse
12+
13+
14+
def analysis(args):
    """Score every generated scaffold against its reference motif structure.

    For each motif case in the index tables below, the reference structure is
    read from ``./data-bin/scaffolding-pdbs/<name>_reference.pdb`` and the
    predicted structures from
    ``<args.scaffold_dir>/scaffold_fasta/esmfold_pdb/<name>/``. The motif
    RMSD, mean pLDDT and TM-score of every prediction are collected and
    written to ``<args.scaffold_dir>/scaffold_results/rmsd_tmscore.csv``.

    Args:
        args: Object exposing a ``scaffold_dir`` attribute (the root
            directory of the generation results).
    """
    # Motif segment start offsets for each reference PDB. They are shifted by
    # the residue number found on the first line of the reference file before
    # being used as residue-number selections.
    start_idx_dict = {
        "1prw": [15, 51],
        "1bcf": [17, 46, 90, 122],
        "5tpn": [108],
        "3ixt": [0],
        "4jhw": [37, 144],
        "4zyp": [357],
        "5wn9": [1],
        "5ius": [34, 88],
        "5yui": [89, 114, 194],
        "6vw1": [5, 45],
        "1qjg": [13, 37, 98],
        "1ycr": [2],
        "2kl8": [0, 27],
        "7mrx": [25],
        "5trv": [45],
        "6e6r": [22],
        "6exz": [25],
    }
    # Motif segment end offsets, paired element-wise with start_idx_dict.
    end_idx_dict = {
        "1prw": [34, 70],
        "1bcf": [24, 53, 98, 129],
        "5tpn": [126],
        "3ixt": [23],
        "4jhw": [43, 159],
        "4zyp": [371],
        "5wn9": [20],
        "5ius": [53, 109],
        "5yui": [93, 116, 196],
        "6vw1": [23, 63],
        "1qjg": [13, 37, 98],
        "1ycr": [10],
        "2kl8": [6, 78],
        "7mrx": [46],
        "5trv": [69],
        "6e6r": [34],
        "6exz": [39],
    }

    def calculate_avg_plddt(pdb_file):
        """Return the mean CA B-factor of ``pdb_file``.

        The B-factor of each CA atom is assumed to store the per-residue
        pLDDT value.

        Raises:
            ValueError: If the structure contains no CA atoms.
        """
        parser = PDB.PDBParser(QUIET=True)
        structure = parser.get_structure("protein", pdb_file)

        # Collect the B-factor of the CA atom of every residue that has one.
        plddt_values = [
            residue["CA"].get_bfactor()
            for model in structure
            for chain in model
            for residue in chain
            if "CA" in residue
        ]
        if not plddt_values:
            raise ValueError(f"No CA atoms found in {pdb_file}")
        return np.mean(plddt_values)

    def calc_rmsd_tmscore(
        pdb_name,
        reference_PDB,
        scaffold_pdb_path=None,
        scaffold_info_path=None,
        ref_motif_starts=(30,),
        ref_motif_ends=(44,),
        output_path=None,
    ):
        """Calculate RMSD between reference structure and generated structure over the defined motif regions.

        Also records the mean pLDDT of each generated structure and the
        TM-score reported by the ``./analysis/TMscore`` executable.

        Args:
            pdb_name: Motif case name; selects ``<scaffold_info_path>/<pdb_name>.csv``
                and the directory ``<scaffold_pdb_path>/<pdb_name>``.
            reference_PDB: Path of the reference PDB file.
            scaffold_pdb_path: Directory holding one sub-directory of
                predicted ``.pdb`` files per motif case.
            scaffold_info_path: Directory holding the per-case CSV with the
                ``start_idxs`` / ``end_idxs`` columns.
            ref_motif_starts: First residue number of each motif segment in
                the reference.
            ref_motif_ends: Last residue number of each motif segment in the
                reference.
            output_path: Directory in which the temporary TMscore output
                file is written.

        Returns:
            A list of ``(pdb_name, index, rmsd, plddt, tm_score)`` tuples,
            one per predicted structure. ``tm_score`` is the raw string
            taken from the TMscore output, or ``None`` if no "TM-score"
            line was found.
        """
        motif_df = pd.read_csv(
            os.path.join(scaffold_info_path, f"{pdb_name}.csv"), index_col=0
        )
        pdb_dir = os.path.join(scaffold_pdb_path, f"{pdb_name}")
        tm_out_path = os.path.join(output_path, "temp_tmscores.txt")
        num_segments = len(ref_motif_starts)

        # The reference structure and its motif selection are identical for
        # every prediction, so build them once instead of once per file.
        ref = mda.Universe(reference_PDB)
        ref_selection = "name CA and resnum " + "".join(
            f"{ref_motif_starts[j]}:{ref_motif_ends[j]} "
            for j in range(num_segments)
        )
        ref_atoms = ref.select_atoms(ref_selection)

        results = []
        # Each row index is parsed from the file name, so sorting only makes
        # the order of the output rows deterministic.
        pdb_files = sorted(
            name for name in os.listdir(pdb_dir) if name.endswith(".pdb")
        )
        for pdb in pdb_files:
            predict_PDB = os.path.join(pdb_dir, pdb)
            u = mda.Universe(predict_PDB)

            # File names look like "<prefix>_<index>.pdb"; the index picks the
            # matching row of the scaffold info table.
            i = int(pdb.split("_")[1].split(".")[0])
            new_motif_starts = literal_eval(motif_df["start_idxs"].iloc[i])
            new_motif_ends = literal_eval(motif_df["end_idxs"].iloc[i])

            # Indices from the info table are shifted by one to obtain the
            # residue numbers used in the predicted PDB.
            u_selection = "name CA and resnum " + "".join(
                f"{new_motif_starts[j] + 1}:{new_motif_ends[j] + 1} "
                for j in range(num_segments)
            )
            gen_atoms = u.select_atoms(u_selection)

            print("U SELECTION", u_selection)
            print("SEQUENCE", i)
            print("ref", ref_atoms.resnames)
            print("gen", gen_atoms.resnames)
            # This asserts that the motif sequences are the same - if you get
            # this error something about your indices are incorrect - check
            # chain/numbering
            assert len(ref_atoms.resnames) == len(gen_atoms.resnames), (
                "Motif lengths do not match, check PDB preprocessing "
                "for extra residues"
            )
            assert (
                ref_atoms.resnames == gen_atoms.resnames
            ).all(), "Resnames for motifRMSD do not match, check indexing"

            rmsd = rms.rmsd(
                gen_atoms.positions,  # coordinates to align
                ref_atoms.positions,  # reference coordinates
                center=True,  # subtract the center of geometry
                superposition=True,  # superimpose coordinates
            )

            # Close the handle before reading the file back so that the
            # TMscore output is complete on disk and the descriptor is freed.
            with open(tm_out_path, "w") as temp_file:
                subprocess.call(
                    ["./analysis/TMscore", reference_PDB, predict_PDB, "-seq"],
                    stdout=temp_file,
                )

            # Reset per structure so that a missing "TM-score" line can never
            # silently reuse the value of the previous structure.
            tm_score = None
            with open(tm_out_path, "r") as f:
                for line in f:
                    fields = line.split()
                    if len(fields) > 2 and fields[0] == "TM-score":
                        tm_score = fields[2]
                        break

            plddt = calculate_avg_plddt(predict_PDB)
            results.append((pdb_name, i, rmsd, plddt, tm_score))
        return results

    scaffold_dir = args.scaffold_dir
    output_dir = os.path.join(scaffold_dir, "scaffold_results")
    os.makedirs(output_dir, exist_ok=True)

    results = []
    for pdb in start_idx_dict:
        print(pdb)
        reference_PDB = os.path.join(
            "./data-bin/scaffolding-pdbs", pdb + "_reference.pdb"
        )
        with open(reference_PDB) as f:
            line = f.readline()
        # NOTE(review): assumes the first line of the reference file is an
        # atom record whose sixth whitespace-separated field is the residue
        # number - confirm for files with headers or fused columns.
        ref_basenum = int(line.split()[5])
        ref_motif_starts = [num + ref_basenum for num in start_idx_dict[pdb]]
        ref_motif_ends = [num + ref_basenum for num in end_idx_dict[pdb]]
        results.extend(
            calc_rmsd_tmscore(
                pdb_name=pdb,
                reference_PDB=reference_PDB,
                scaffold_pdb_path=f"{scaffold_dir}/scaffold_fasta/esmfold_pdb",
                scaffold_info_path=f"{scaffold_dir}/scaffold_info",
                ref_motif_starts=ref_motif_starts,
                ref_motif_ends=ref_motif_ends,
                output_path=output_dir,
            )
        )

    results = pd.DataFrame(
        results, columns=["pdb_name", "index", "rmsd", "plddt", "tmscore"]
    )
    results.to_csv(os.path.join(output_dir, "rmsd_tmscore.csv"), index=False)
193+
194+
195+
def cal_success_scaffold(pdb, rmsd_threshold=1.0, plddt_threshold=70):
    """Return the successful scaffolds among the results of one motif case.

    A row counts as successful when its ``rmsd`` is strictly below
    ``rmsd_threshold`` and its ``plddt`` is strictly above
    ``plddt_threshold``.

    Args:
        pdb: DataFrame with at least the columns ``rmsd`` and ``plddt``.
        rmsd_threshold: Exclusive upper bound on ``rmsd``.
        plddt_threshold: Exclusive lower bound on ``plddt``.

    Returns:
        A new DataFrame holding only the successful rows, with an extra
        ``total`` column set to the number of rows in the input. The input
        frame itself is left unmodified.
    """
    total = len(pdb)
    # assign() builds a new frame, so the caller's data is not mutated.
    annotated = pdb.assign(total=total)
    is_success = (annotated["rmsd"] < rmsd_threshold) & (
        annotated["plddt"] > plddt_threshold
    )
    return annotated[is_success]
200+
201+
202+
def motif_evaluation(args):
    """Run the scaffold analysis and summarise the success count per motif.

    Calls ``analysis(args)``, reloads the ``rmsd_tmscore.csv`` it wrote,
    counts the rows kept by ``cal_success_scaffold`` for every ``pdb_name``
    (motifs with no kept row get a count of 0), writes the table to
    ``result.csv`` in the same results directory and prints it.

    Args:
        args: Object exposing a ``scaffold_dir`` attribute.
    """
    analysis(args)

    result_dir = os.path.join(args.scaffold_dir, "scaffold_results")
    scores = pd.read_csv(os.path.join(result_dir, "rmsd_tmscore.csv"))

    # Keep only the successful rows of each motif, then count them per motif.
    passing = scores.groupby("pdb_name", as_index=False).apply(
        cal_success_scaffold
    )
    success_counts = (
        passing.groupby("pdb_name").size().reset_index(name="success_count")
    )

    # Motifs without any successful row vanish from the grouped counts above,
    # so add them back explicitly with a count of zero.
    every_pdb = list(scores["pdb_name"].unique())
    succeeded = list(success_counts["pdb_name"])
    failed_pdb = list(set(every_pdb) - set(succeeded))
    zero_counts = pd.DataFrame(
        {
            "pdb_name": failed_pdb,
            "success_count": [0] * len(failed_pdb),
        }
    )

    results = pd.concat([success_counts, zero_counts]).sort_values("pdb_name")
    results.to_csv(os.path.join(result_dir, "result.csv"))
    print(results)
225+
226+
227+
def main():
    """Command-line entry point: read ``--scaffold_dir`` and run the motif evaluation."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--scaffold_dir", type=str, default="./generation-results")
    motif_evaluation(cli.parse_args())


# Run the evaluation only when the file is executed as a script.
if __name__ == "__main__":
    main()

analysis/plddt_calculate.sh

+2-4
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ pdb_path=$output_dir/esmfold_pdb
1515
mkdir -p $pdb_path
1616

1717
echo 'folding by ESMFold'
18-
python cal_plddt_dir.py -i ${output_dir} -o ${pdb_path} --max-tokens-per-batch ${max_tokens} \
19-
-m ${ROOTDIR}/cache-dir
20-
21-
echo "============================Finish Evaluation=============================="
18+
python analysis/cal_plddt_dir.py -i ${output_dir} -o ${pdb_path} --max-tokens-per-batch ${max_tokens}
2219

20+
echo "============================Finish Evaluation=============================="

configs/datamodule/uniref50.yaml

-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ data_dir: ${paths.data_dir}/uniref50
55
# dataloader related
66
max_tokens: 6000
77
max_len: 1022
8-
sort: false
98
num_workers: 8
10-
pin_memory: true
119

1210
mini_run: false

0 commit comments

Comments
 (0)