Consolidated ChatGPT API improvements: improve compatibility, add request-specific token limits, and textual stop sequences #734

Open

Wants to merge 33 commits into base: main

Commits (33)
75c0598
Strip EOS token from output to mirror the OpenAI behaviour
joshuacoles Feb 19, 2025
cf40bd6
Add "[DONE]" message and change streaming response to mirror OpenAI
joshuacoles Feb 19, 2025
244b8ed
Fix bench.py to support "[DONE]" terminating event in stream
joshuacoles Feb 20, 2025
2c2e907
Resolve out of range error in debug line
joshuacoles Feb 20, 2025
6c23ba0
Add generation options protocol buf definition and corresponding pyth…
joshuacoles Feb 20, 2025
929503d
Fish `generation_options` through from the ChatGPT API request to `pr…
joshuacoles Feb 20, 2025
c5e001a
Apply the generation options to inference
joshuacoles Feb 20, 2025
4323a2b
Emit the finish reason completion chunk separately from the content t…
joshuacoles Feb 20, 2025
605a60f
Add `stop` to the ChatGPT API and `GenerationOptions`
joshuacoles Feb 20, 2025
1dbae1e
Refactor `process_inference_result`
joshuacoles Feb 20, 2025
76a6716
Add finish_reason to the `SendResultRequest` proto type
joshuacoles Feb 20, 2025
93622ce
Pipe finish reason around the place
joshuacoles Feb 20, 2025
f92d7c2
Add finish_reason extraction to ChatGPT API
joshuacoles Feb 20, 2025
b9f2009
Add missing parameter to update_topology_viz
joshuacoles Feb 20, 2025
37cad05
Move finish reason determination to the process_inference_result func…
joshuacoles Feb 20, 2025
cea5f68
Move to BufferedOutput object
joshuacoles Feb 20, 2025
7b8b52f
Add `BufferedOutput#token_count`
joshuacoles Feb 20, 2025
cb85946
Move next token determination to `BufferedOutput`
joshuacoles Feb 20, 2025
4a85723
Delay emission by a number of tokens to simulate keeping a buffer aro…
joshuacoles Feb 20, 2025
4678a94
Return None from `process_inference_result` as it is unused.
joshuacoles Feb 20, 2025
d737fe8
Resolve issue with stop parameters in GRPC communication
joshuacoles Feb 20, 2025
24985ed
Skip empty completions before finish as caused by buffering
joshuacoles Feb 20, 2025
d49b58b
Resolve issue with shared array being used for BufferedOutput
joshuacoles Feb 20, 2025
286335c
Move more logic into BufferedOutput
joshuacoles Feb 20, 2025
25a91bf
Move to keeping the text and tokens so that we can search for stop se…
joshuacoles Feb 20, 2025
c5b501a
Initial stop sequence work
joshuacoles Feb 20, 2025
1c6f7c1
Search for stop sequences in the entire string
joshuacoles Feb 20, 2025
1857aad
Handle partial tokens and retokenise when tokens are split
joshuacoles Feb 20, 2025
0328f5e
Buffer by character count rather than token count
joshuacoles Feb 20, 2025
90b8069
Handle no stop sequences being provided
joshuacoles Feb 20, 2025
3731605
Fix issue with finish_reason not being defined in stable diffusion br…
joshuacoles Feb 21, 2025
e88e37d
Add tests for the different ChatGPT API features introduced in this PR
joshuacoles Mar 7, 2025
2f4b1ac
Add test for immediate stop sequence and fix issues resulting from th…
joshuacoles Mar 7, 2025
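The features named in the title (request-specific token limits via max_completion_tokens, textual stop sequences via stop, and OpenAI-style streaming termination) are exercised through the OpenAI-compatible chat completions endpoint. A minimal non-streaming sketch, assuming an exo node on its default port 52415, the standard /v1/chat/completions path, and a hypothetical model id (none of these are taken from the diff below):

# Sketch only: shows the request fields this PR wires through.
import requests

resp = requests.post(
  "http://localhost:52415/v1/chat/completions",  # assumed endpoint path
  json={
    "model": "llama-3.2-1b",  # hypothetical model id
    "messages": [{"role": "user", "content": "List the planets in order."}],
    "temperature": 0.0,
    "max_completion_tokens": 64,  # per-request limit; max_tokens is accepted as a deprecated fallback
    "stop": ["Jupiter"],          # a string or list of strings; generation halts before emitting it
  },
  timeout=120,
)
choice = resp.json()["choices"][0]
print(choice["message"]["content"])
print("finish_reason:", choice["finish_reason"])  # "stop" (EOS or stop sequence hit) or "length"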
3 changes: 3 additions & 0 deletions .github/bench.py
@@ -265,6 +265,9 @@ async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dic
if not line.startswith('data: '):
continue

if line == 'data: [DONE]':
break

data = json.loads(line[6:]) # Skip 'data: ' prefix
if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
print(f"Received content: {content}", flush=True)
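Any streaming client can rely on the same termination contract: content arrives as delta chunks, the finish_reason is emitted in a separate final chunk, and the stream ends with an explicit data: [DONE] event. A small consumer sketch mirroring the loop in bench.py above (the endpoint path and model id are assumptions):

import json
import requests

with requests.post(
  "http://localhost:52415/v1/chat/completions",  # assumed endpoint path
  json={
    "model": "llama-3.2-1b",  # hypothetical model id
    "messages": [{"role": "user", "content": "Say hello."}],
    "stream": True,
    "max_completion_tokens": 32,
    "stop": ["\n\n"],
  },
  stream=True,
) as resp:
  for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith('data: '):
      continue
    if line == 'data: [DONE]':  # explicit end-of-stream marker added by this PR
      break
    choice = json.loads(line[6:])["choices"][0]  # skip the 'data: ' prefix
    if content := choice.get("delta", {}).get("content"):
      print(content, end="", flush=True)
    if choice.get("finish_reason"):
      print(f"\nfinish_reason={choice['finish_reason']}")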
124 changes: 84 additions & 40 deletions exo/api/chatgpt_api.py
@@ -14,6 +14,7 @@
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.inference.generation_options import GenerationOptions
from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name
from typing import Callable, Optional
from PIL import Image
@@ -47,15 +48,22 @@ def to_dict(self):


class ChatCompletionRequest:
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None,
max_completion_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None):
self.model = model
self.messages = messages
self.temperature = temperature
self.tools = tools
self.max_completion_tokens = max_completion_tokens
self.stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else None

def to_dict(self):
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
return {"model": self.model, "messages": [message.to_dict() for message in self.messages],
"temperature": self.temperature, "tools": self.tools, "max_completion_tokens": self.max_completion_tokens,
"stop": self.stop}

def to_generation_options(self) -> GenerationOptions:
return GenerationOptions(max_completion_tokens=self.max_completion_tokens, stop=self.stop)

def generate_completion(
chat_request: ChatCompletionRequest,
@@ -67,6 +75,7 @@ def generate_completion(
finish_reason: Union[Literal["length", "stop"], None],
object_type: Literal["chat.completion", "text_completion"],
) -> dict:
decoded_tokens = tokenizer.decode(tokens)
completion = {
"id": f"chatcmpl-{request_id}",
"object": object_type,
@@ -75,7 +84,6 @@ def generate_completion(
"system_fingerprint": f"exo_{VERSION}",
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": tokenizer.decode(tokens)},
"logprobs": None,
"finish_reason": finish_reason,
}],
@@ -90,10 +98,12 @@ def generate_completion(

choice = completion["choices"][0]
if object_type.startswith("chat.completion"):
key_name = "delta" if stream else "message"
choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
if stream:
choice["delta"] = {"role": "assistant", "content": decoded_tokens} if len(decoded_tokens) > 0 else {}
else:
choice["message"] = {"role": "assistant", "content": decoded_tokens}
elif object_type == "text_completion":
choice["text"] = tokenizer.decode(tokens)
choice["text"] = decoded_tokens
else:
ValueError(f"Unsupported response type: {object_type}")

@@ -137,7 +147,7 @@ def remap_messages(messages: List[Message]) -> List[Message]:
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
messages = remap_messages(_messages)
chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
if tools:
if tools:
chat_template_args["tools"] = tools

try:
@@ -147,7 +157,7 @@ def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]
except UnicodeEncodeError:
# Handle Unicode encoding by ensuring everything is UTF-8
chat_template_args["conversation"] = [
{k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
{k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
for k, v in m.to_dict().items()}
for m in messages
]
@@ -168,6 +178,10 @@ def parse_chat_request(data: dict, default_model: str):
[parse_message(msg) for msg in data["messages"]],
data.get("temperature", 0.0),
data.get("tools", None),
# The max_tokens field is deprecated, but some clients may still use it, fall back to that value if
# max_completion_tokens is not provided.
data.get("max_completion_tokens", data.get("max_tokens", None)),
data.get("stop", None),
)


@@ -201,7 +215,7 @@ def __init__(

# Get the callback system and register our handler
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished)))
self.token_callback.on_next(lambda _request_id, tokens, is_finished, finish_reason: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished, finish_reason)))
self.system_prompt = system_prompt

cors = aiohttp_cors.setup(self.app)
@@ -234,7 +248,7 @@ def __init__(
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")

# Always add images route, regardless of compilation status
self.images_dir = get_exo_images_dir()
self.images_dir.mkdir(parents=True, exist_ok=True)
@@ -357,7 +371,12 @@ async def handle_post_chat_completions(self, request):
if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")

try:
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(
shard,
prompt,
request_id=request_id,
generation_options=chat_request.to_generation_options()
))), timeout=self.response_timeout)

if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")

@@ -376,36 +395,59 @@ async def handle_post_chat_completions(self, request):
# Stream tokens while waiting for inference to complete
while True:
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
tokens, is_finished = await asyncio.wait_for(
tokens, is_finished, finish_reason = await asyncio.wait_for(
self.token_queues[request_id].get(),
timeout=self.response_timeout
)
if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}")
if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=} {finish_reason=}")

eos_token_id = None
if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")

finish_reason = None
if is_finished: finish_reason = "stop" if tokens[-1] == eos_token_id else "length"
if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=} {finish_reason=}")

completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
tokens,
stream,
finish_reason,
"chat.completion",
)
if len(tokens) == 0 and not is_finished:
continue

if len(tokens) > 0:
if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=}")
if is_finished:
if tokens[-1] == eos_token_id:
# We do not return the EOS token in the response
tokens.pop(-1)

if DEBUG >= 2: print(f"{finish_reason=}")

if len(tokens) > 0:
completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
tokens,
stream,
None,
"chat.completion",
)

await response.write(f"data: {json.dumps(completion)}\n\n".encode())
await response.write(f"data: {json.dumps(completion)}\n\n".encode())

if is_finished:
completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
[],
stream,
finish_reason,
"chat.completion",
)

await response.write(f"data: {json.dumps(completion)}\n\n".encode())
break

# Send the DONE event when the stream is finished
await response.write(b"data: [DONE]\n\n")
await response.write_eof()
return response

@@ -414,7 +456,7 @@ async def handle_post_chat_completions(self, request):
return web.json_response({"detail": "Response generation timed out"}, status=408)

except Exception as e:
if DEBUG >= 2:
if DEBUG >= 2:
print(f"[ChatGPTAPI] Error processing prompt: {e}")
traceback.print_exc()
return web.json_response(
@@ -430,17 +472,19 @@ async def handle_post_chat_completions(self, request):
else:
tokens = []
while True:
_tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
_tokens, is_finished, finish_reason = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
tokens.extend(_tokens)
if is_finished:
break
finish_reason = "length"

eos_token_id = None
if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
if tokens[-1] == eos_token_id:
finish_reason = "stop"
if len(tokens) > 0:
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
if tokens[-1] == eos_token_id:
# We do not return the EOS token in the response
tokens.pop(-1)

return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
except asyncio.TimeoutError:
@@ -501,22 +545,22 @@ async def stream_image(_request_id: str, result, is_finished: bool):
image_filename = f"{_request_id}.png"
image_path = self.images_dir/image_filename
im.save(image_path)

# Get URL for the saved image
try:
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
full_image_url = base_url + str(image_url)

await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
except KeyError as e:
if DEBUG >= 2: print(f"Error getting image URL: {e}")
# Fallback to direct file path if URL generation fails
await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')

if is_finished:
await response.write_eof()

except Exception as e:
if DEBUG >= 2: print(f"Error processing image: {e}")
if DEBUG >= 2: traceback.print_exc()
@@ -620,8 +664,8 @@ async def handle_get_topology(self, request):
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)

async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool):
await self.token_queues[request_id].put((tokens, is_finished))
async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool, finish_reason: Optional[str] = None):
await self.token_queues[request_id].put((tokens, is_finished, finish_reason))

async def run(self, host: str = "0.0.0.0", port: int = 52415):
runner = web.AppRunner(self.app)
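GenerationOptions itself is imported from exo.inference.generation_options, a module this PR adds but which is not part of the excerpt above. A hypothetical minimal shape, inferred only from how ChatCompletionRequest.to_generation_options() constructs it, could look like:

# Hypothetical sketch, not the PR's actual definition: only the two keyword
# arguments passed by to_generation_options() above are assumed here.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class GenerationOptions:
  max_completion_tokens: Optional[int] = None  # per-request cap on generated tokens
  stop: Optional[List[str]] = None  # textual stop sequences, already normalised to a list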