Skip to content

Commit

Permalink
Add simple example to run Hyena inference
Browse files Browse the repository at this point in the history
  • Loading branch information
bputzeys committed May 21, 2024
1 parent 811cd6c commit 6adb98f
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 10 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,7 @@ jobs:
python examples/run_geneformer.py
- name: Execute UCE
run: |
python examples/run_uce.py
python examples/run_uce.py
- name: Execute Hyena
run: |
python examples/run_hyena_dna.py
5 changes: 4 additions & 1 deletion examples/run_hyena_dna.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Minimal example: load a pretrained HyenaDNA model and embed one DNA sequence.
from helical.models.hyena_dna.model import HyenaDNA, HyenaDNAConfig
# Pick the checkpoint by name; HyenaDNAConfig looks up its settings from that name.
hyena_config = HyenaDNAConfig(model_name = "hyenadna-tiny-1k-seqlen-d256")
model = HyenaDNA(configurer = hyena_config)
# NOTE(review): in this diff view the next print looks like the line the commit
# removed (the hunk header says 4 lines became 7) - confirm before relying on it.
print("Done")
# Repeat the 4-base motif 256 times to get a 1024-character input sequence.
sequence = 'ACTG' * int(1024/4)
# Tokenize the string into the batched id tensor that get_embeddings() consumes.
data = model.process_data(sequence)
# Forward pass; the result is whatever the underlying model returns for the ids.
embeddings = model.get_embeddings(data)
print(embeddings.shape)
5 changes: 4 additions & 1 deletion helical/models/hyena_dna/hyena_dna_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ def __init__(
"hyenadna-tiny-1k-seqlen": {
'd_model': 128,
'd_inner': 512,
'max_length': 1024,
},
"hyenadna-tiny-1k-seqlen-d256": {
'd_model': 256,
'd_inner': 1024,
'max_length': 1024, # TODO double check this
}
}

Expand All @@ -97,7 +99,8 @@ def __init__(
"residual_in_fp32": residual_in_fp32,
"pad_vocab_size_multiple": pad_vocab_size_multiple,
"return_hidden_state": return_hidden_state,
"layer": layer
"layer": layer,
"max_length": self.model_map[model_name]['max_length']
}


Expand Down
34 changes: 30 additions & 4 deletions helical/models/hyena_dna/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from helical.models.hyena_dna.hyena_dna_config import HyenaDNAConfig
from helical.models.helical import HelicalBaseModel
from helical.models.hyena_dna.pretrained_model import HyenaDNAPreTrainedModel
import torch
from .standalone_hyenadna import CharacterTokenizer

class HyenaDNA(HelicalBaseModel):
"""HyenaDNA model."""
default_configurer = HyenaDNAConfig()
Expand All @@ -12,10 +15,33 @@ def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None:
self.config = configurer.config
self.log = logging.getLogger("Hyena-DNA-Model")

self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

self.model = HyenaDNAPreTrainedModel().from_pretrained(self.config)

def process_data(self):
pass
# create tokenizer
self.tokenizer = CharacterTokenizer(
characters=['A', 'C', 'G', 'T', 'N'], # add DNA characters, N is uncertain
model_max_length=self.config['max_length'] + 2, # to account for special tokens, like EOS
add_special_tokens=False, # we handle special tokens elsewhere
padding_side='left', # since HyenaDNA is causal, we pad on the left
)

# prep model and forward
self.model.to(self.device)
self.model.eval()


def process_data(self, sequence: str) -> torch.Tensor:
    """Tokenize a DNA sequence into a batched tensor of token ids.

    Args:
        sequence: DNA string to encode. It is handed unchanged to
            ``self.tokenizer``.

    Returns:
        A ``torch.long`` tensor placed on ``self.device`` that holds the
        tokenizer's ``input_ids`` with a leading batch dimension of size 1.
    """
    token_ids = self.tokenizer(sequence)["input_ids"]

    # torch.tensor is the recommended way to build a tensor from existing
    # data (torch.LongTensor is the legacy typed constructor). Passing
    # dtype/device creates the tensor directly on the target device instead
    # of allocating on CPU first and copying it over afterwards.
    batch = torch.tensor(token_ids, dtype=torch.long, device=self.device)

    # The model expects a batch axis, so a single sequence becomes (1, n).
    return batch.unsqueeze(0)

def get_embeddings(self):
pass
def get_embeddings(self, tok_seq):
    """Run the wrapped model on tokenized input and return its output.

    The forward pass executes inside ``torch.inference_mode()``, so no
    autograd state is recorded for the call.
    """
    with torch.inference_mode():
        embeddings = self.model(tok_seq)
    return embeddings
1 change: 0 additions & 1 deletion helical/models/hyena_dna/pretrained_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from transformers import PreTrainedModel
import re
from .standalone_hyenadna import HyenaDNAModel
from .standalone_hyenadna import CharacterTokenizer

# helper 1
def inject_substring(orig_str):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ scikit-learn==1.2.2
gitpython==3.1.43
torch==2.0.0
accelerate==0.29.3
transformers==4.35.0
transformers==4.26.1
loompy==3.0.7
scib==1.1.5
datasets==2.14.7
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
'gitpython==3.1.43',
'torch==2.3.0',
'accelerate==0.29.3',
'transformers==4.35.0',
'transformers==4.26.1',
'loompy==3.0.7',
'scib==1.1.5',
'datasets==2.14.7',
Expand Down

0 comments on commit 6adb98f

Please sign in to comment.