Commit 168a445

Put flash attention 2 into ProGen2 (#6)
* adapt for esmfold
* add tests
* fix hidden_states
* update esmfold benchmark
* Delete scripts directory
* update
* update progen2
* update FAprogen2
* Update README.md
* Delete tests/progen2.py
* update progen tests
* update .gitignore
1 parent 371d014 commit 168a445

File tree

7 files changed: +1265 -7 lines


.gitignore (+2)

```diff
@@ -162,3 +162,5 @@ cython_debug/
 #.idea/
 # vim
 *.sw?
+
+tests/progen2/
```

README.md (+45 -1)

````diff
@@ -53,9 +53,10 @@ pip install faesm

 ## ESM2

-FAESM is a drop-in replacement for the official ESM implementation. You can use the same code as you would use the official ESM implementation. For example:import torch
+FAESM is a drop-in replacement for the official ESM implementation. You can use the same code as you would use the official ESM implementation. For example:

 ```python
+import torch
 from faesm.esm import FAEsmForMaskedLM

 # Step 1: Load the FAESM model
@@ -73,6 +74,47 @@ print("Repr shape:", outputs['last_hidden_state'].shape) # (batch_size, sequenc
 # Step 5: star the repo if the code works for u!
 ```

+
+
+
+## ProGen2
+
+For generative protein language models like ProGen2:
+
+```python
+import torch
+from faesm.progen2 import ProGenForCausalLM
+from transformers import AutoTokenizer
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# Available models on HF: ["jinyuan22/ProGen2-small", "jinyuan22/ProGen2-base", "jinyuan22/ProGen2-xlarge"]
+model = ProGenForCausalLM.from_pretrained("jinyuan22/ProGen2-small").to(torch.float16).to(device).eval()
+tokenizer = AutoTokenizer.from_pretrained("jinyuan22/ProGen2-small")
+
+sequence = "2GFLPFRGADEGLAAREAATLAARGTAARAYREDSWAVPVPRGLLGDLTARVAALGAASPPPADPLAVTLDLHHVTAEVALTTVLDAATLVHGQTRVLSAEDAAEAATAAAAATEAYLERLQDFVLFMSASVRVWRRGNAAGATGPEWDQWYTVADRDALGSAPTHLAVLGRQADALCHFVLDRVAWGTCGTPLWSGDEDLGNVVATFAGYADRLATAPRDLIM1"
+
+inputs = tokenizer(sequence, return_tensors="pt").to(device)
+target = inputs.input_ids[0,...]
+with torch.no_grad():
+    logits = model(inputs.input_ids, labels=inputs.input_ids).logits[0,...]
+
+logits = logits[:-1, ...]
+target = target[1:]
+
+bos_token, eos_token = 3, 4
+if target[-1] in [bos_token, eos_token]:
+    logits = logits[:-1, ...]
+    target = target[:-1]
+
+# remove unused logits
+first_token, last_token = 5, 29
+logits = logits[:, first_token:(last_token+1)]
+target = target - first_token
+
+ce_eval = torch.nn.functional.cross_entropy(input=logits.view(-1, logits.size(-1)), target=target.view(-1), reduction="mean").item()
+print(ce_eval)
+assert abs(ce_eval - 2.4) < 0.1  # 2.4 is the reference CE for the official ProGen2-small
+```
+
 ## ESM-C

 Right after EvolutionaryScale release [ESM-C](https://www.evolutionaryscale.ai/blog/esm-cambrian), we follow up with the flash attention version of ESM-C in FAESM. You can run ESM-C easily with the following code:
@@ -85,6 +127,7 @@ input_ids = model.tokenizer(sequence, return_tensors="pt")["input_ids"].to("cuda
 output = model(input_ids)
 print(output.sequence_logits.shape)
 print(output.embeddings.shape)
+
 ```

 ### Training \[WIP\]
@@ -94,6 +137,7 @@ It's recommended to use the flash attention for training. Because in the forward

 # Benchmarking

+
 ### FAESM vs. Official ESM2

 Below is the comparison of peak memory usage and inference time of FAESM with the official ESM2. We show that FAESM can save memory usage by up to 60% and inference time by up to 70% (length 1000). The benchmarking is done on ESM-650M with batch size 8, and a single A100 with 80GB of memory.
````

assets/figs/FAProGen2_benchmark.png (binary file, 605 KB)
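The benchmark paragraph added to the README above compares peak GPU memory and inference time of FAESM against the official ESM2 (ESM-650M, batch size 8, single A100 80GB), with the resulting figure stored at the path above. As a rough illustration of how such numbers can be measured on the FAESM side, here is a minimal sketch; the checkpoint name, the `model.tokenizer` call (mirroring the README's ESM-C snippet), the dummy sequences, and the timing loop are assumptions, not the repository's actual benchmark script.

```python
# Minimal sketch (assumptions): one timed forward pass of FAESM ESM-650M.
# Checkpoint name, dummy sequences, and fp16 cast are illustrative, not the repo's benchmark code.
import time
import torch
from faesm.esm import FAEsmForMaskedLM

device = "cuda"
model = FAEsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D")  # ESM-650M (assumed checkpoint name)
model = model.to(device).eval().half()  # flash attention kernels expect fp16/bf16 inputs

batch_size, length = 8, 1000
sequences = ["A" * length] * batch_size  # dummy sequences at the benchmarked length
inputs = model.tokenizer(sequences, return_tensors="pt", padding=True).to(device)

torch.cuda.reset_peak_memory_stats(device)
torch.cuda.synchronize(device)
start = time.perf_counter()
with torch.no_grad():
    model(**inputs)
torch.cuda.synchronize(device)

print(f"inference time: {time.perf_counter() - start:.3f} s")
print(f"peak memory: {torch.cuda.max_memory_allocated(device) / 2**30:.2f} GiB")
```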

faesm/esm.py (+9 -4)

```diff
@@ -589,14 +589,19 @@ def forward(
             attention_mask=attention_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
+            output_hidden_states=output_hidden_states,  # For the hidden states
         )
         sequence_output = outputs[0]
         logits = self.lm_head(sequence_output)

-        result = {
-            "logits": logits,
-            "last_hidden_state": sequence_output,
-        }
+        if outputs.hidden_states is not None:
+            result = {
+                "logits": logits,
+                "last_hidden_state": sequence_output,
+                "hidden_states": [x.unsqueeze(0) for x in outputs.hidden_states],
+            }
+        else:
+            result = {"logits": logits, "last_hidden_state": sequence_output}
         return result

     @classmethod
```
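The faesm/esm.py change above threads `output_hidden_states` through to the encoder and, when per-layer states come back, adds them to the returned dict. A minimal usage sketch under that reading of the diff; the checkpoint name and sequence are placeholders, and passing `output_hidden_states=True` directly to the forward call is inferred from the diff rather than documented here.

```python
import torch
from faesm.esm import FAEsmForMaskedLM

# Placeholder checkpoint and sequence; fp16 because flash attention expects half precision.
model = FAEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D").to("cuda").eval().half()
inputs = model.tokenizer("MAIVMGRWKGAR", return_tensors="pt").to("cuda")

with torch.no_grad():
    # output_hidden_states is assumed to be a forward kwarg, as the diff passes it through
    outputs = model(**inputs, output_hidden_states=True)

print(outputs["logits"].shape)             # (batch, seq_len, vocab_size)
print(outputs["last_hidden_state"].shape)  # (batch, seq_len, hidden_size)
# Per-layer states, each unsqueezed to a leading dim of 1 by the new code path
print(len(outputs["hidden_states"]), outputs["hidden_states"][0].shape)
```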
