Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deepseek r1 tool call example polyfill (template newly adds trailing <think>) #52

Merged
merged 6 commits into from
Feb 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions include/minja/chat-template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,25 @@ class chat_template {
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
full = full.substr(0, eos_pos_last);
}
if (full.find(prefix) != 0) {
size_t common_prefix_length = 0;
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
if (prefix[i] != full[i]) {
break;
}
if (prefix[i] == '<') {
// DeepSeek R1's template (as of 20250209) adds a trailing <think> when add_generation_prompt is set,
// but it removes thinking tags from past messages.
// The prefix and full strings therefore diverge at <think> vs. <|tool▁calls▁begin|>, so we never let the
// common prefix end right after a '<': the shared leading '<' stays part of the extracted example.
continue;
}
common_prefix_length = i + 1;
}
auto example = full.substr(common_prefix_length);
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
} else {
tool_call_example_ = example;
}
tool_call_example_ = full.substr(prefix.size());
}
} catch (const std::exception & e) {
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
Expand Down
67 changes: 38 additions & 29 deletions scripts/fetch_templates_and_goldens.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ def raise_exception(message: str):
raise ValueError(message)


def tojson(eval_ctx, value, indent=None):
return json.dumps(value, indent=indent)

TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26')


Expand Down Expand Up @@ -114,16 +111,22 @@ def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, ext
# print(out, file=sys.stderr)
return out
except BaseException as e:
# print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
# print(f"Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
return ""

def __init__(self, template, known_eos_tokens, env=None):
def __init__(self, template, env=None, filters=None, global_functions=None):
if not env:
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[jinja2.ext.loopcontrols]
)
if filters:
for name, func in filters.items():
env.filters[name] = func
if global_functions:
for name, func in global_functions.items():
env.globals[name] = func
self.env = env
self.template = env.from_string(template)

Expand Down Expand Up @@ -243,15 +246,24 @@ def make_tool_call(tool_name, arguments):
}
prefix = self.try_raw_render([user_msg], add_generation_prompt=True)
full = self.try_raw_render([user_msg, tool_call_msg], add_generation_prompt=False)
if not full.startswith(prefix):
for known_eos_token in known_eos_tokens:
prefix = prefix.rstrip()
if prefix.endswith(known_eos_token):
prefix = prefix[:-len(known_eos_token)]
break
if not full.startswith(prefix):

common_prefix_length = 0
for i in range(min(len(prefix), len(full))):
if prefix[i] != full[i]:
break
if prefix[i] == '<':
# DeepSeek R1's template (as of 20250209) adds a trailing <think> when add_generation_prompt is set,
# but it removes thinking tags from past messages.
# The prefix and full strings therefore diverge at <think> vs. <|tool▁calls▁begin|>, so we never let the
# common prefix end right after a '<': the shared leading '<' stays part of the extracted example.
continue
common_prefix_length = i + 1

example = full[common_prefix_length:]
if "tool_name" not in example and "some_value" not in example:
print("Failed to infer a tool call example (possible template bug)", file=sys.stderr)
self.tool_call_example = full[len(prefix):]
else:
self.tool_call_example = example

except Exception as e:
print(f"Failed to generate tool call example: {e}", file=sys.stderr)

Expand Down Expand Up @@ -321,7 +333,11 @@ def apply(self, context):
message['content'] = [{"type": "text", "text": message['content']}]

try:
return self.template.render(**context)
out = self.template.render(**context)
out = out.replace("\\u0027", "'")
out = out.replace('&#34;', '"')
out = out.replace('&#39;', "'")
return out
except Exception as e1:
for message in context['messages']:
if message.get("content") is None:
Expand Down Expand Up @@ -350,21 +366,14 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c
async with aiofiles.open(template_file, 'w') as f:
await f.write(template_src)

known_eos_tokens = [
"<|END_OF_TURN_TOKEN|>",
"<end_of_turn>",
"</s>",
"<|im_end|>",
"<|eom_id|>",
"<|eot_id|>",
"<|end▁of▁sentence|>",
]

template = chat_template(template_src, known_eos_tokens)
template.env.filters['safe'] = lambda x: x
template.env.filters['tojson'] = tojson
template.env.globals['raise_exception'] = raise_exception
template.env.globals['strftime_now'] = strftime_now
template = chat_template(template_src,
filters={
'safe': lambda x: x,
},
global_functions={
'raise_exception': raise_exception,
'strftime_now': strftime_now,
})
caps = template.original_caps

if not context_files:
Expand Down
8 changes: 3 additions & 5 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,8 @@ target_link_libraries(test-polyfills PRIVATE
)
if (NOT CMAKE_CROSSCOMPILING)
gtest_discover_tests(test-syntax)
endif()

if (NOT CMAKE_CROSSCOMPILING)
gtest_discover_tests(test-syntax)
gtest_discover_tests(test-polyfills)
add_test(NAME test-polyfills COMMAND test-polyfills)
set_tests_properties(test-polyfills PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
endif()

add_executable(test-capabilities test-capabilities.cpp)
Expand Down Expand Up @@ -82,6 +79,7 @@ set(MODEL_IDS
MiniMaxAI/MiniMax-Text-01
indischepartij/MiniCPM-3B-OpenHermes-2.5-v2
mattshumer/Reflection-Llama-3.1-70B
meetkai/functionary-medium-v3.1
meetkai/functionary-medium-v3.2
meta-llama/Llama-3.1-8B-Instruct # Gated
meta-llama/Llama-3.2-3B-Instruct # Gated
Expand Down
3 changes: 2 additions & 1 deletion tests/contexts/simple.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
],
"add_generation_prompt": true,
"bos_token": "<|startoftext|>",
"eos_token": "<|endoftext|>"
"eos_token": "<|endoftext|>",
"tools_in_user_message": false
}
3 changes: 2 additions & 1 deletion tests/contexts/system.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
],
"add_generation_prompt": true,
"bos_token": "<|startoftext|>",
"eos_token": "<|endoftext|>"
"eos_token": "<|endoftext|>",
"tools_in_user_message": false
}
57 changes: 29 additions & 28 deletions tests/contexts/tool_use.json
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
"add_generation_prompt": true,
"bos_token": "<|startoftext|>",
"eos_token": "<|endoftext|>",
"tools_in_user_message": false,
"builtin_tools": [
"wolfram_alpha",
"brave_search"
Expand All @@ -96,72 +97,72 @@
"todays_date": "2024-09-03",
"tools": [
{
"type": "function",
"function": {
"name": "ipython",
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"name": "ipython",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The code to run in the ipython interpreter."
"description": "The code to run in the ipython interpreter.",
"type": "string"
}
},
"required": ["code"]
"required": ["code"],
"type": "object"
}
}
},
"type": "function"
},
{
"type": "function",
"function": {
"name": "brave_search",
"description": "Executes a web search with Brave.",
"name": "brave_search",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for."
"description": "The query to search for.",
"type": "string"
}
},
"required": ["query"]
"required": ["query"],
"type": "object"
}
}
},
"type": "function"
},
{
"type": "function",
"function": {
"name": "wolfram_alpha",
"description": "Executes a query with Wolfram Alpha.",
"name": "wolfram_alpha",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to execute."
"description": "The query to execute.",
"type": "string"
}
},
"required": ["query"]
"required": ["query"],
"type": "object"
}
}
},
"type": "function"
},
{
"type": "function",
"function": {
"name": "test",
"description": "Runs a test.",
"name": "test",
"parameters": {
"type": "object",
"properties": {
"condition": {
"type": "boolean",
"description": "The condition to test."
"description": "The condition to test.",
"type": "boolean"
}
},
"required": ["condition"]
"required": ["condition"],
"type": "object"
}
}
},
"type": "function"
}
]
}
Loading