
Commit 012f33c

Add a python script for generating text using huggingface gpt2 (#2983)
* Add a script for generating text using gpt2

Signed-off-by: Tung D. Le <[email protected]>
1 parent c80d3e0 commit 012f33c

1 file changed: run_gpt2_from_huggingface.py (+148, -0 lines)
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0

##################### run_gpt2_from_huggingface.py #############################
#
# Copyright 2019-2024 The IBM Research Authors.
#
################################################################################
#
# This script runs GPT-2 from HuggingFace.
# Command: ONNX_MLIR_HOME=/workdir/onnx-mlir/build/Debug python run_gpt2_from_huggingface.py 2>&1 | tee log.txt
#
# When running this script for the first time, it will download onnx models from
# HuggingFace and compile the models. The onnx models and compiled models are
# cached in the current folder (by default).
#
# Change compile_flags if targeting a different machine.
################################################################################

import os
import sys
import time
import json
import requests as req
from urllib.request import urlretrieve

import numpy as np
from transformers import AutoTokenizer

# Include the runtime directory in the python path, so PyRuntime can be imported.
RUNTIME_DIR = os.path.join(os.environ["ONNX_MLIR_HOME"], "lib")
sys.path.append(RUNTIME_DIR)
try:
    from PyCompileAndRuntime import OMCompileExecutionSession
except ImportError:
    raise ImportError(
        "Looks like you did not build the PyRuntime target, build it by running `make PyRuntime`. "
        "You may need to set ONNX_MLIR_HOME to `onnx-mlir/build/Debug` since `make PyRuntime` outputs to `build/Debug` by default."
    )

# Information to download onnx models from HuggingFace.
model_name_or_path = "gpt2"  # can be gpt2, gpt2-medium, gpt2-large
decoder_model_name = "decoder_model.onnx"
decoder_with_past_model_name = "decoder_with_past_model.onnx"
config_json_name = "config.json"
decoder_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{decoder_model_name}"
decoder_with_past_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{decoder_with_past_model_name}"
config_json_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{config_json_name}"

# Local directories for caching the model.
cache_dir = "./"
decoder_model_path = f"{cache_dir}/{decoder_model_name}"
decoder_with_past_model_path = f"{cache_dir}/{decoder_with_past_model_name}"
config_json_path = f"{cache_dir}/{config_json_name}"

# Download the model to a local dir.
if not os.path.exists(decoder_model_path):
    print(f"Downloading {decoder_url}")
    urlretrieve(decoder_url, decoder_model_path)
    print("Done")
    if req.head(f"{decoder_url}_data", allow_redirects=True).status_code == 200:
        print(f"Downloading {decoder_url}_data")
        urlretrieve(decoder_url + "_data", decoder_model_path + "_data")
        print("Done")
if not os.path.exists(decoder_with_past_model_path):
    print(f"Downloading {decoder_with_past_url}")
    urlretrieve(decoder_with_past_url, decoder_with_past_model_path)
    print("Done")
    if req.head(f"{decoder_with_past_url}_data", allow_redirects=True).status_code == 200:
        print(f"Downloading {decoder_with_past_url}_data")
        urlretrieve(decoder_with_past_url + "_data", decoder_with_past_model_path + "_data")
        print("Done")
if not os.path.exists(config_json_path):
    print(f"Downloading the config json file {config_json_url}")
    urlretrieve(config_json_url, config_json_path)
    print("Done")

with open(config_json_path) as f:
    cfg = json.load(f)
print("Model configuration: {}\n".format(cfg))
num_attention_heads = cfg["n_head"]
hidden_size = cfg["n_embd"]
num_layers = cfg["n_layer"]
eos_token_id = cfg["eos_token_id"]
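# For reference, the base "gpt2" config reports n_head=12, n_embd=768, and
# n_layer=12 (the medium and large variants are bigger); these hyperparameters
# determine the shapes of the key/value cache tensors used during generation.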
85+
86+
# Create CompileExecutionSession to compile and run the model,
87+
compile_flags = "-O3 -v --onnx-op-stats TXT"
88+
# compile_flags = "-O3 -mcpu=z16 -maccel=NNPA -v --onnx-op-stats TXT"
89+
decoder_sess = OMCompileExecutionSession(
90+
decoder_model_path, compile_flags + " -tag=decoder", reuse_compiled_model=1
91+
)
92+
decoder_with_past_sess = OMCompileExecutionSession(
93+
decoder_with_past_model_path,
94+
compile_flags + " -tag=decoder_with_past",
95+
reuse_compiled_model=1,
96+
)
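# The two sessions implement the usual prompt/decode split: decoder_sess is run
# once on the full prompt and returns logits plus the initial key/value cache,
# while decoder_with_past_sess is called once per generated token with only the
# newest token id, the cached key/values, and the growing attention mask (see
# the generation loop below).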

# Setup a tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)

# Tokenize the input text.
prompt_text = "Which is the highest mountain in Japan?"
pt_inputs = tokenizer(prompt_text, return_tensors="pt")
output_length = 13

# Generate tokens.
ts = []
num_runs = 3
for r in range(num_runs):
    output_ids = []
    kv_cache = None

    t = 0
    attention_mask = pt_inputs["attention_mask"].numpy()
    inputs = [pt_inputs["input_ids"].numpy(), attention_mask]
    for step in range(output_length):
        t0 = time.time()
        if kv_cache is None:
            outputs = decoder_sess.run(inputs)
        else:
            outputs = decoder_with_past_sess.run(inputs)
        t_elap = time.time() - t0
        t += t_elap
        # Greedy approach is used here.
        logits = outputs[0][:, -1, :]
        next_id = np.argmax(logits, axis=1, keepdims=True)
        kv_cache = outputs[1:]
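        # outputs[0] holds the logits; for the HuggingFace ONNX export the
        # remaining outputs are expected to be the present key/value tensors
        # (two per layer), which are fed back in as inputs on the next step.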
        # Only for batchsize = 1
        attention_mask = np.append(
            attention_mask, np.array([[1]], dtype=np.int64), axis=1
        )
        inputs = [next_id] + kv_cache + [attention_mask]
        output_ids += [next_id[0][0]]

    ts += [t]
    if r == num_runs - 1:
        # Expected answer: "The highest mountain in Japan is the Mt. Fuji."
        print("Prompt: {}".format(prompt_text))
        print("Generated words: {}\n".format(tokenizer.decode(output_ids).strip()))

print("times", ts)
t = np.min(ts)
print("t_elap: %.2f seconds" % (t))
print(
    "Latency: {} msec/token, throughput: {} tokens/sec".format(
        t / output_length * 1000, output_length / t
    )
)
