
Commit 1f39f92

Update PythonListCustomToolGenerator to support overriding system prompt (#271)
Summary: Support a user-supplied prompt template in PythonListCustomToolGenerator. This allows users to provide their own system prompt without having to format the function descriptions themselves.

Test Plan: python -m unittest llama_models.llama3.tests.prompt_templates.test_system_prompts
1 parent ecf2f12 commit 1f39f92
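
For context, a minimal usage sketch of the new call pattern (not part of this commit). The module path llama_models.llama3.prompt_templates.system_prompts is inferred from the test plan and the imports in the diff below; the sample tools come from the generator's own data_examples(), as in the new test.

# Illustration only: override the system prompt while letting the generator
# fill in the {{ function_description }} placeholder.
import textwrap

from llama_models.llama3.prompt_templates.system_prompts import (
    PythonListCustomToolGenerator,
)

generator = PythonListCustomToolGenerator()
tools = generator.data_examples()[0]  # sample ToolDefinition list used by the tests

custom_prompt = textwrap.dedent(
    """
    Overriding message.

    {{ function_description }}
    """
)

# system_prompt is optional; omitting it falls back to DEFAULT_PROMPT.
print(generator.gen(tools, system_prompt=custom_prompt).render())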

2 files changed: +73 -8 lines changed

models/llama3/prompt_templates/system_prompts.py

+27 -8
@@ -7,7 +7,7 @@

 import textwrap
 from datetime import datetime
-from typing import Any, List
+from typing import Any, List, Optional

 from llama_models.llama3.api.datatypes import (
     BuiltinTool,
@@ -215,14 +215,33 @@ def data_examples(self) -> List[List[ToolDefinition]]:


 class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
-    def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+    DEFAULT_PROMPT = textwrap.dedent(
+        """
+        You are an expert in composing functions. You are given a question and a set of possible functions.
+        Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+        If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
+        also point it out. You should only return the function call in tools call sections.
+
+        {{ function_description }}
+        """.strip(
+            "\n"
+        )
+    )
+
+    def gen(
+        self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None
+    ) -> PromptTemplate:
+        system_prompt = system_prompt or self.DEFAULT_PROMPT
+        return PromptTemplate(
+            system_prompt,
+            {"function_description": self._gen_function_description(custom_tools)},
+        )
+
+    def _gen_function_description(
+        self, custom_tools: List[ToolDefinition]
+    ) -> PromptTemplate:
         template_str = textwrap.dedent(
             """
-            You are an expert in composing functions. You are given a question and a set of possible functions.
-            Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
-            If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
-            also point it out. You should only return the function call in tools call sections.
-
             If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
             You SHOULD NOT include any other text in the response.

@@ -263,7 +282,7 @@ def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
         return PromptTemplate(
             template_str.strip("\n"),
             {"tools": [t.model_dump() for t in custom_tools]},
-        )
+        ).render()

     def data_examples(self) -> List[List[ToolDefinition]]:
         return [

models/llama3/tests/prompt_templates/test_system_prompts.py

+46
@@ -145,3 +145,49 @@ def test_llama_3_2_system_zero_shot(self):
             """
         )
         self.check_generator_output(generator, expected_text.strip("\n"))
+
+    def test_llama_3_2_provided_system_prompt(self):
+        generator = PythonListCustomToolGenerator()
+        expected_text = textwrap.dedent(
+            """
+            Overriding message.
+
+            If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
+            You SHOULD NOT include any other text in the response.
+
+            Here is a list of functions in JSON format that you can invoke.
+
+            [
+                {
+                    "name": "get_weather",
+                    "description": "Get weather info for places",
+                    "parameters": {
+                        "type": "dict",
+                        "required": ["city"],
+                        "properties": {
+                            "city": {
+                                "type": "string",
+                                "description": "The name of the city to get the weather for"
+                            },
+                            "metric": {
+                                "type": "string",
+                                "description": "The metric for weather. Options are: celsius, fahrenheit",
+                                "default": "celsius"
+                            }
+                        }
+                    }
+                }
+            ]"""
+        )
+        user_system_prompt = textwrap.dedent(
+            """
+            Overriding message.
+
+            {{ function_description }}
+            """
+        )
+        example = generator.data_examples()[0]
+
+        pt = generator.gen(example, user_system_prompt)
+        text = pt.render()
+        assert text == expected_text, f"Expected:\n{expected_text}\nActual:\n{text}"
