Fix deepseek r1 tool call example polyfill (template newly adds trailing <think>) (#52)

* Fix deepseek r1 tool call example polyfill (their template newly adds trailing <think>)

* test tool outputs for common templates

* tests: align extra context in c++ w/ python + remove python tojson override
ochafik authored Feb 9, 2025
1 parent e259cda commit 7eb5202
Showing 8 changed files with 310 additions and 75 deletions.
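
The gist of the fix: instead of requiring that the render with a tool-call turn ("full") start with the render of the bare generation prompt ("prefix"), the polyfill now keeps only the common prefix of the two renders, and it never consumes a '<', so the special token at the point of divergence survives into the extracted example. A minimal, self-contained sketch of that scan follows; the two strings are hypothetical stand-ins, not DeepSeek R1's actual template output.

# Sketch of the common-prefix scan added in this commit (illustration only).
# The strings are hypothetical stand-ins for the two template renders.
prefix = "<|User|>Hi<|Assistant|><think>"                  # add_generation_prompt=True
full   = "<|User|>Hi<|Assistant|><|tool▁calls▁begin|>..."  # includes a past tool-call turn

# Old check: fails, because the renders diverge at "<think>" vs. "<|tool▁calls▁begin|>".
assert not full.startswith(prefix)

# New scan: take the longest common prefix, but never consume a '<', so the
# divergent special token stays intact in the extracted example.
common_prefix_length = 0
for i in range(min(len(prefix), len(full))):
    if prefix[i] != full[i]:
        break
    if prefix[i] == '<':
        continue
    common_prefix_length = i + 1

tool_call_example = full[common_prefix_length:]
print(tool_call_example)  # "<|tool▁calls▁begin|>..."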
19 changes: 17 additions & 2 deletions include/minja/chat-template.hpp
@@ -254,10 +254,25 @@ class chat_template {
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
full = full.substr(0, eos_pos_last);
}
if (full.find(prefix) != 0) {
size_t common_prefix_length = 0;
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
if (prefix[i] != full[i]) {
break;
}
if (prefix[i] == '<') {
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
// but it removes thinking tags for past messages.
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
continue;
}
common_prefix_length = i + 1;
}
auto example = full.substr(common_prefix_length);
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
} else {
tool_call_example_ = example;
}
tool_call_example_ = full.substr(prefix.size());
}
} catch (const std::exception & e) {
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
67 changes: 38 additions & 29 deletions scripts/fetch_templates_and_goldens.py
@@ -43,9 +43,6 @@ def raise_exception(message: str):
raise ValueError(message)


def tojson(eval_ctx, value, indent=None):
return json.dumps(value, indent=indent)

TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26')


@@ -114,16 +111,22 @@ def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, ext
# print(out, file=sys.stderr)
return out
except BaseException as e:
# print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
# print(f"Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
return ""

def __init__(self, template, known_eos_tokens, env=None):
def __init__(self, template, env=None, filters=None, global_functions=None):
if not env:
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[jinja2.ext.loopcontrols]
)
if filters:
for name, func in filters.items():
env.filters[name] = func
if global_functions:
for name, func in global_functions.items():
env.globals[name] = func
self.env = env
self.template = env.from_string(template)

@@ -243,15 +246,24 @@ def make_tool_call(tool_name, arguments):
}
prefix = self.try_raw_render([user_msg], add_generation_prompt=True)
full = self.try_raw_render([user_msg, tool_call_msg], add_generation_prompt=False)
if not full.startswith(prefix):
for known_eos_token in known_eos_tokens:
prefix = prefix.rstrip()
if prefix.endswith(known_eos_token):
prefix = prefix[:-len(known_eos_token)]
break
if not full.startswith(prefix):

common_prefix_length = 0
for i in range(min(len(prefix), len(full))):
if prefix[i] != full[i]:
break
if prefix[i] == '<':
# DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
# but it removes thinking tags for past messages.
# The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
continue
common_prefix_length = i + 1

example = full[common_prefix_length:]
if "tool_name" not in example and "some_value" not in example:
print("Failed to infer a tool call example (possible template bug)", file=sys.stderr)
self.tool_call_example = full[len(prefix):]
else:
self.tool_call_example = example

except Exception as e:
print(f"Failed to generate tool call example: {e}", file=sys.stderr)

@@ -321,7 +333,11 @@ def apply(self, context):
message['content'] = [{"type": "text", "text": message['content']}]

try:
return self.template.render(**context)
out = self.template.render(**context)
out = out.replace("\\u0027", "'")
out = out.replace('&#34;', '"')
out = out.replace('&#39;', "'")
return out
except Exception as e1:
for message in context['messages']:
if message.get("content") is None:
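
Context for the three replace() calls above: with the custom tojson override removed (see the deletion near the top of this file), the script falls back to Jinja2's built-in tojson filter, which HTML-safes its output, so an apostrophe comes out as \u0027; the &#34; / &#39; entities likewise show up when quotes get HTML-escaped during rendering. Normalizing those sequences keeps the Python goldens aligned with the C++ renderer's plain JSON output. A small sketch of the difference, using only stock jinja2 and json (nothing from this repo):

# Why the unescaping is needed: Jinja2's built-in tojson is HTML-safe.
import json
import jinja2

env = jinja2.Environment()
out = env.from_string("{{ args | tojson }}").render(args={"q": "it's"})
print(out)                          # {"q": "it\u0027s"}   (apostrophe escaped by tojson)
print(json.dumps({"q": "it's"}))    # {"q": "it's"}        (plain json.dumps for comparison)
print(out.replace("\\u0027", "'"))  # {"q": "it's"}        (normalized back)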
@@ -350,21 +366,14 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c
async with aiofiles.open(template_file, 'w') as f:
await f.write(template_src)

known_eos_tokens = [
"<|END_OF_TURN_TOKEN|>",
"<end_of_turn>",
"</s>",
"<|im_end|>",
"<|eom_id|>",
"<|eot_id|>",
"<|end▁of▁sentence|>",
]

template = chat_template(template_src, known_eos_tokens)
template.env.filters['safe'] = lambda x: x
template.env.filters['tojson'] = tojson
template.env.globals['raise_exception'] = raise_exception
template.env.globals['strftime_now'] = strftime_now
template = chat_template(template_src,
filters={
'safe': lambda x: x,
},
global_functions={
'raise_exception': raise_exception,
'strftime_now': strftime_now,
})
caps = template.original_caps

if not context_files:
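
For reference, the filters and global_functions arguments introduced above are simply forwarded into the underlying Jinja2 environment. A standalone sketch of that registration pattern, using plain jinja2 rather than the script's chat_template class (the strftime_now lambda is a stand-in for the script's helper):

# Registering filters and globals on a Jinja2 environment, the same pattern
# the new chat_template constructor wraps.
import datetime
import jinja2

env = jinja2.Environment(
    trim_blocks=True,
    lstrip_blocks=True,
    extensions=[jinja2.ext.loopcontrols],
)
env.filters['safe'] = lambda x: x  # no-op: chat templates are not HTML
env.globals['strftime_now'] = lambda fmt: datetime.datetime.now().strftime(fmt)

print(env.from_string("Today is {{ strftime_now('%d %b %Y') }}.").render())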
8 changes: 3 additions & 5 deletions tests/CMakeLists.txt
@@ -31,11 +31,8 @@ target_link_libraries(test-polyfills PRIVATE
)
if (NOT CMAKE_CROSSCOMPILING)
gtest_discover_tests(test-syntax)
endif()

if (NOT CMAKE_CROSSCOMPILING)
gtest_discover_tests(test-syntax)
gtest_discover_tests(test-polyfills)
add_test(NAME test-polyfills COMMAND test-polyfills)
set_tests_properties(test-polyfills PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
endif()

add_executable(test-capabilities test-capabilities.cpp)
@@ -82,6 +79,7 @@ set(MODEL_IDS
MiniMaxAI/MiniMax-Text-01
indischepartij/MiniCPM-3B-OpenHermes-2.5-v2
mattshumer/Reflection-Llama-3.1-70B
meetkai/functionary-medium-v3.1
meetkai/functionary-medium-v3.2
meta-llama/Llama-3.1-8B-Instruct # Gated
meta-llama/Llama-3.2-3B-Instruct # Gated
3 changes: 2 additions & 1 deletion tests/contexts/simple.json
@@ -11,5 +11,6 @@
],
"add_generation_prompt": true,
"bos_token": "<|startoftext|>",
"eos_token": "<|endoftext|>"
"eos_token": "<|endoftext|>",
"tools_in_user_message": false
}
3 changes: 2 additions & 1 deletion tests/contexts/system.json
@@ -15,5 +15,6 @@
],
"add_generation_prompt": true,
"bos_token": "<|startoftext|>",
"eos_token": "<|endoftext|>"
"eos_token": "<|endoftext|>",
"tools_in_user_message": false
}
57 changes: 29 additions & 28 deletions tests/contexts/tool_use.json
@@ -88,6 +88,7 @@
"add_generation_prompt": true,
"bos_token": "<|startoftext|>",
"eos_token": "<|endoftext|>",
"tools_in_user_message": false,
"builtin_tools": [
"wolfram_alpha",
"brave_search"
@@ -96,72 +97,72 @@
"todays_date": "2024-09-03",
"tools": [
{
"type": "function",
"function": {
"name": "ipython",
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"name": "ipython",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The code to run in the ipython interpreter."
"description": "The code to run in the ipython interpreter.",
"type": "string"
}
},
"required": ["code"]
"required": ["code"],
"type": "object"
}
}
},
"type": "function"
},
{
"type": "function",
"function": {
"name": "brave_search",
"description": "Executes a web search with Brave.",
"name": "brave_search",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for."
"description": "The query to search for.",
"type": "string"
}
},
"required": ["query"]
"required": ["query"],
"type": "object"
}
}
},
"type": "function"
},
{
"type": "function",
"function": {
"name": "wolfram_alpha",
"description": "Executes a query with Wolfram Alpha.",
"name": "wolfram_alpha",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to execute."
"description": "The query to execute.",
"type": "string"
}
},
"required": ["query"]
"required": ["query"],
"type": "object"
}
}
},
"type": "function"
},
{
"type": "function",
"function": {
"name": "test",
"description": "Runs a test.",
"name": "test",
"parameters": {
"type": "object",
"properties": {
"condition": {
"type": "boolean",
"description": "The condition to test."
"description": "The condition to test.",
"type": "boolean"
}
},
"required": ["condition"]
"required": ["condition"],
"type": "object"
}
}
},
"type": "function"
}
]
}
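
The key reordering in the tool definitions above is alphabetical within each object (description before name, type last). One way to produce contexts in that canonical order, assuming the goal stated in the commit message of keeping the C++ and Python extra context aligned, is json.dumps with sort_keys; the snippet below is an illustration, not part of the test scripts.

# Hypothetical helper: emit a tool definition with alphabetically sorted keys,
# matching the ordering used in the updated tool_use.json above.
import json

tool = {
    "type": "function",
    "function": {
        "name": "test",
        "description": "Runs a test.",
        "parameters": {
            "type": "object",
            "properties": {
                "condition": {"type": "boolean", "description": "The condition to test."}
            },
            "required": ["condition"],
        },
    },
}
print(json.dumps(tool, indent=2, sort_keys=True))
# "function" (with description, name, parameters sorted inside) sorts before "type"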
