forked from Yushi-Hu/IC-DST
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprompting.py
executable file
·157 lines (135 loc) · 5.25 KB
/
prompting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from utils.sql import slot_values_to_seq_sql
# Schema description (hotel, train, attraction, restaurant and taxi tables, each
# with a few example rows) placed at the top of every prompt by get_prompt,
# after `conversion` renames some of its column names.
# Area values are spelled "centre" throughout and every area column allows
# "dontcare", so the example rows agree with the CHECK constraints shown.
table_prompt = """
CREATE TABLE hotel(
name text,
pricerange text CHECK (pricerange IN (dontcare, cheap, moderate, expensive)),
type text CHECK (type IN (hotel, guest house)),
parking text CHECK (parking IN (dontcare, yes, no)),
book_stay int,
book_day text,
book_people int,
area text CHECK (area IN (dontcare, centre, east, north, south, west)),
stars int CHECK (stars IN (dontcare, 0, 1, 2, 3, 4, 5)),
internet text CHECK (internet IN (dontcare, yes, no))
)
/*
4 example rows:
SELECT * FROM hotel LIMIT 4;
name pricerange type parking book_stay book_day book_people area stars internet
a and b guest house moderate guest house dontcare 3 friday 5 east 4 yes
ashley hotel expensive hotel yes 2 thursday 5 north 5 yes
el shaddia guest house cheap guest house yes 5 friday 2 centre dontcare no
express by holiday inn cambridge dontcare guest house yes 3 monday 2 east dontcare no
*/
CREATE TABLE train(
destination text,
departure text,
day text,
book_people int,
leaveat text,
arriveby text
)
/*
3 example rows:
SELECT * FROM train LIMIT 3;
destination departure day book_people leaveat arriveby
london kings cross cambridge monday 6 dontcare 05:51
cambridge stansted airport dontcare 1 20:24 20:52
peterborough cambridge saturday 2 12:06 12:56
*/
CREATE TABLE attraction(
name text,
area text CHECK (area IN (dontcare, centre, east, north, south, west)),
type text CHECK (type IN (architecture, boat, church, cinema, college, concert hall, entertainment, hotspot, multiple sports, museum, nightclub, park, special, swimming pool, theatre))
)
/*
4 example rows:
SELECT * FROM attraction LIMIT 4;
name area type
abbey pool and astroturf pitch centre swimming pool
adc theatre centre theatre
all saints church dontcare architecture
castle galleries centre museum
*/
CREATE TABLE restaurant(
name text,
food text,
pricerange text CHECK (pricerange IN (dontcare, cheap, moderate, expensive)),
area text CHECK (area IN (dontcare, centre, east, north, south, west)),
book_time text,
book_day text,
book_people int
)
/*
5 example rows:
SELECT * FROM restaurant LIMIT 5;
name food pricerange area book_time book_day book_people
pizza hut city centre italian dontcare centre 13:30 wednesday 7
the missing sock international moderate east dontcare dontcare 2
golden wok chinese moderate north 17:11 friday 4
cambridge chop house dontcare expensive centre 08:43 monday 5
darrys cookhouse and wine shop modern european expensive centre 11:20 saturday 8
*/
CREATE TABLE taxi(
destination text,
departure text,
leaveat text,
arriveby text
)
/*
3 example rows:
SELECT * FROM taxi LIMIT 3;
destination departure leaveat arriveby
copper kettle royal spice 14:45 15:30
magdalene college university arms hotel dontcare 15:45
lovell lodge da vinci pizzeria 11:45 dontcare
*/
-- Using valid SQLite, answer the following multi-turn conversational questions for the tables provided above.
"""
def conversion(prompt, reverse=False):
    """Rename a few column/slot names between the data schema and the prompt schema.

    Forward direction: ``leaveat`` -> ``depart_time``, ``arriveby`` ->
    ``arrive_by_time``, ``book_stay`` -> ``book_number_of_days`` and
    ``food`` -> ``food_type``. With ``reverse=True`` the mapping is inverted.

    A name is only replaced when it is not directly attached to other letters,
    so values that merely contain a name (e.g. ``seafood``) are left intact.
    Non-letter separators still match, so ``restaurant-food``,
    ``restaurant.food``, ``food text`` and ``food_type`` are all rewritten.
    NOTE: a value in which a name is a separate word (e.g. ``fast food``) is
    still rewritten, as before.

    Args:
        prompt: Text to rewrite.
        reverse: If True, map the descriptive names back to the original ones.

    Returns:
        The rewritten text.
    """
    import re

    conversion_dict = {"leaveat": "depart_time", "arriveby": "arrive_by_time",
                       "book_stay": "book_number_of_days",
                       "food": "food_type"}
    reverse_conversion_dict = {v: k for k, v in conversion_dict.items()}
    used_dict = reverse_conversion_dict if reverse else conversion_dict
    for k, v in used_dict.items():
        # Letter-only lookarounds: plain substring replacement would turn the
        # value "seafood" into "seafood_type".
        pattern = rf"(?<![A-Za-z]){re.escape(k)}(?![A-Za-z])"
        prompt = re.sub(pattern, v, prompt)
    return prompt
def _strip_alternatives(slot_values):
    """Keep only the first of any '|'-separated alternatives in each slot value."""
    return {s: v.split('|')[0] for s, v in slot_values.items()}


def _format_slot_values(slot_values):
    """Render a slot->value dict as "slot: value, ..." in the dict's own order."""
    # Joining a generator (not a set) keeps the order stable across interpreter
    # runs; set iteration order depends on per-process string hash randomization.
    return ', '.join(f'{slot}: {value}' for slot, value in slot_values.items())


def _format_turn(slot_values, dialog):
    """Render the [context] / [system] / Q: [user] lines for one turn."""
    last_sys_utt = dialog['sys'][-1]
    # A system utterance of 'none' is rendered as an empty system turn.
    if last_sys_utt == 'none':
        last_sys_utt = ''
    return (f"[context] {conversion(_format_slot_values(slot_values))}\n"
            f"[system] {last_sys_utt}\n"
            f"Q: [user] {dialog['usr'][-1]}\n")


def get_prompt(data_item, examples, given_context=None, n_examples=None):
    """Build the few-shot text-to-SQL prompt for one dialogue turn.

    The prompt consists of the schema text ``table_prompt`` (passed through
    ``conversion``), then one numbered block per selected example, and finally
    the query turn. Each example block holds its ``last_slot_values`` as
    ``[context]``, its last system and user utterances, and the SQL built from
    its ``turn_slot_values``. The query turn ends with an open
    ``SQL: SELECT * FROM`` for the model to complete.

    Args:
        data_item: The turn to build the prompt for. Reads
            ``['last_slot_values']`` and ``['dialog']['sys']`` /
            ``['dialog']['usr']`` (the last element of each is used).
        examples: In-context examples with the same keys plus
            ``['turn_slot_values']``. The *last* ``n_examples`` entries are used.
        given_context: Optional slot->value dict to use as the query turn's
            context instead of ``data_item['last_slot_values']``. It is used
            as-is ('|' alternatives are not stripped).
        n_examples: Number of examples to include. ``None`` includes all of
            them; ``0`` or less gives a zero-shot prompt.

    Returns:
        The prompt string.
    """
    prompt_text = f"{conversion(table_prompt)}\n"

    max_n_examples = len(examples) if n_examples is None else n_examples
    # Guard explicitly: examples[-0:] would select *every* example.
    used_examples = examples[-max_n_examples:] if max_n_examples > 0 else []

    for example_id, example in enumerate(used_examples):
        prompt_text += f"Example #{example_id + 1}\n"
        # Remove multiple-choice alternatives from the example's slot values.
        prompt_text += _format_turn(_strip_alternatives(example['last_slot_values']),
                                    example['dialog'])
        sql = conversion(slot_values_to_seq_sql(example['turn_slot_values']))
        prompt_text += f"SQL: {sql};\n"
        prompt_text += "\n\n"

    # Number the query turn after the examples actually shown, which may be
    # fewer than n_examples when the pool is smaller.
    prompt_text += f"Example #{len(used_examples) + 1}\n"
    if given_context is None:
        last_slot_values = _strip_alternatives(data_item['last_slot_values'])
    else:
        last_slot_values = given_context
    prompt_text += _format_turn(last_slot_values, data_item['dialog'])
    prompt_text += "SQL: SELECT * FROM"
    return prompt_text