Max completion tokens #720

Closed
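
For context, a minimal client sketch of what the new option looks like on the wire. The endpoint path, port, and model id below are assumptions, not taken from this diff; the max_completion_tokens field (with max_tokens as a deprecated fallback) is what parse_chat_request below accepts.

import requests

payload = {
    "model": "llama-3.2-1b",  # placeholder model id
    "messages": [{"role": "user", "content": "Write a haiku about GPUs."}],
    "temperature": 0.0,
    "max_completion_tokens": 64,  # new field; older clients may still send max_tokens
}
# Assumed OpenAI-compatible route and default port; adjust for your deployment.
resp = requests.post("http://localhost:52415/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])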
3 changes: 3 additions & 0 deletions .github/bench.py
@@ -265,6 +265,9 @@ async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dic
if not line.startswith('data: '):
continue

if line == 'data: [DONE]':
break

data = json.loads(line[6:]) # Skip 'data: ' prefix
if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
print(f"Received content: {content}", flush=True)
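The bench.py hunk above shows only the inner lines of the streaming loop. A self-contained sketch of the full pattern, assuming an aiohttp client and an OpenAI-style SSE endpoint (the function name and payload handling are illustrative, not part of the diff):

import json
import aiohttp

async def stream_chat(api_endpoint: str, payload: dict) -> str:
    # Accumulates streamed delta content until the server sends 'data: [DONE]'.
    content = ""
    async with aiohttp.ClientSession() as session:
        async with session.post(api_endpoint, json=payload) as resp:
            async for raw_line in resp.content:
                line = raw_line.decode("utf-8").strip()
                if not line.startswith("data: "):
                    continue
                if line == "data: [DONE]":  # termination sentinel handled by the change above
                    break
                data = json.loads(line[6:])  # skip the 'data: ' prefix
                delta = data.get("choices", [{}])[0].get("delta", {})
                content += delta.get("content") or ""
    return content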
101 changes: 71 additions & 30 deletions exo/api/chatgpt_api.py
@@ -14,6 +14,7 @@
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.inference.generation_options import GenerationOptions
from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name
from typing import Callable, Optional
from PIL import Image
@@ -47,15 +48,20 @@ def to_dict(self):


class ChatCompletionRequest:
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None,
max_completion_tokens: Optional[int] = None):
self.model = model
self.messages = messages
self.temperature = temperature
self.tools = tools
self.max_completion_tokens = max_completion_tokens

def to_dict(self):
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
return {"model": self.model, "messages": [message.to_dict() for message in self.messages],
"temperature": self.temperature, "tools": self.tools, "max_completion_tokens": self.max_completion_tokens}

def to_generation_options(self) -> GenerationOptions:
return GenerationOptions(max_completion_tokens=self.max_completion_tokens)

def generate_completion(
chat_request: ChatCompletionRequest,
@@ -67,6 +73,7 @@ def generate_completion(
finish_reason: Union[Literal["length", "stop"], None],
object_type: Literal["chat.completion", "text_completion"],
) -> dict:
decoded_tokens = tokenizer.decode(tokens)
completion = {
"id": f"chatcmpl-{request_id}",
"object": object_type,
@@ -75,7 +82,6 @@ def generate_completion(
"system_fingerprint": f"exo_{VERSION}",
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": tokenizer.decode(tokens)},
"logprobs": None,
"finish_reason": finish_reason,
}],
@@ -90,10 +96,12 @@ def generate_completion(

choice = completion["choices"][0]
if object_type.startswith("chat.completion"):
key_name = "delta" if stream else "message"
choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
if stream:
choice["delta"] = {"role": "assistant", "content": decoded_tokens} if len(decoded_tokens) > 0 else {}
else:
choice["message"] = {"role": "assistant", "content": decoded_tokens}
elif object_type == "text_completion":
choice["text"] = tokenizer.decode(tokens)
choice["text"] = decoded_tokens
else:
ValueError(f"Unsupported response type: {object_type}")

@@ -137,7 +145,7 @@ def remap_messages(messages: List[Message]) -> List[Message]:
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
messages = remap_messages(_messages)
chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
if tools:
if tools:
chat_template_args["tools"] = tools

try:
Expand All @@ -147,7 +155,7 @@ def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]
except UnicodeEncodeError:
# Handle Unicode encoding by ensuring everything is UTF-8
chat_template_args["conversation"] = [
{k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
{k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
for k, v in m.to_dict().items()}
for m in messages
]
@@ -168,6 +176,9 @@ def parse_chat_request(data: dict, default_model: str):
[parse_message(msg) for msg in data["messages"]],
data.get("temperature", 0.0),
data.get("tools", None),
# The max_tokens field is deprecated, but some clients may still use it; fall back to that value if
# max_completion_tokens is not provided.
data.get("max_completion_tokens", data.get("max_tokens", None)),
)


@@ -234,7 +245,7 @@ def __init__(
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")

# Always add images route, regardless of compilation status
self.images_dir = get_exo_images_dir()
self.images_dir.mkdir(parents=True, exist_ok=True)
@@ -357,7 +368,12 @@ async def handle_post_chat_completions(self, request):
if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")

try:
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(
shard,
prompt,
request_id=request_id,
generation_options=chat_request.to_generation_options()
))), timeout=self.response_timeout)

if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")

@@ -387,25 +403,48 @@ async def handle_post_chat_completions(self, request):
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")

finish_reason = None
if is_finished: finish_reason = "stop" if tokens[-1] == eos_token_id else "length"
if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=} {finish_reason=}")

completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
tokens,
stream,
finish_reason,
"chat.completion",
)

await response.write(f"data: {json.dumps(completion)}\n\n".encode())
if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=}")
if is_finished:
if tokens[-1] == eos_token_id:
# We do not return the EOS token in the response
tokens.pop(-1)
finish_reason = "stop"
else:
finish_reason = "length"

if DEBUG >= 2: print(f"{finish_reason=}")

if len(tokens) > 0:
completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
tokens,
stream,
None,
"chat.completion",
)

await response.write(f"data: {json.dumps(completion)}\n\n".encode())

if is_finished:
completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
[],
stream,
finish_reason,
"chat.completion",
)

await response.write(f"data: {json.dumps(completion)}\n\n".encode())
break

# Send the DONE event when the stream is finished
await response.write(b"data: [DONE]\n\n")
await response.write_eof()
return response

@@ -414,7 +453,7 @@ async def handle_post_chat_completions(self, request):
return web.json_response({"detail": "Response generation timed out"}, status=408)

except Exception as e:
if DEBUG >= 2:
if DEBUG >= 2:
print(f"[ChatGPTAPI] Error processing prompt: {e}")
traceback.print_exc()
return web.json_response(
@@ -440,6 +479,8 @@ async def handle_post_chat_completions(self, request):
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
if tokens[-1] == eos_token_id:
# We do not return the EOS token in the response
tokens.pop(-1)
finish_reason = "stop"

return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
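
To make the new streaming behaviour concrete: content is now flushed in chunks whose finish_reason stays unset, the trailing EOS token is dropped, and a final empty-delta chunk carries the finish reason before the [DONE] sentinel. A sketch of the resulting chunk payloads (ids, timestamps, and token text are placeholders; the structure follows generate_completion with stream=True):

# Illustrative payloads only; other top-level fields (id, created, model, ...) omitted.
content_chunk = {
    "object": "chat.completion",
    "choices": [{
        "index": 0,
        "delta": {"role": "assistant", "content": "Hello"},
        "logprobs": None,
        "finish_reason": None,
    }],
}
final_chunk = {
    "object": "chat.completion",
    "choices": [{
        "index": 0,
        "delta": {},  # empty: the EOS token was popped, so there is no text left to send
        "logprobs": None,
        "finish_reason": "stop",  # or "length" when the token budget is exhausted
    }],
}
# The handler then writes: data: [DONE]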
@@ -501,22 +542,22 @@ async def stream_image(_request_id: str, result, is_finished: bool):
image_filename = f"{_request_id}.png"
image_path = self.images_dir/image_filename
im.save(image_path)

# Get URL for the saved image
try:
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
full_image_url = base_url + str(image_url)

await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
except KeyError as e:
if DEBUG >= 2: print(f"Error getting image URL: {e}")
# Fallback to direct file path if URL generation fails
await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')

if is_finished:
await response.write_eof()

except Exception as e:
if DEBUG >= 2: print(f"Error processing image: {e}")
if DEBUG >= 2: traceback.print_exc()
8 changes: 8 additions & 0 deletions exo/inference/generation_options.py
@@ -0,0 +1,8 @@
from typing import Optional


class GenerationOptions:
max_completion_tokens: Optional[int] = None

def __init__(self, max_completion_tokens: Optional[int] = None):
self.max_completion_tokens = max_completion_tokens
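
A quick sketch of how this container is produced in the API layer and handed to the node, following the chatgpt_api.py changes above (the model id is a placeholder and the message list is left empty for brevity):

# Illustrative only; real requests carry parsed Message objects.
req = ChatCompletionRequest(
    model="llama-3.2-1b",  # placeholder model id
    messages=[],
    temperature=0.0,
    max_completion_tokens=64,
)
opts = req.to_generation_options()
assert opts.max_completion_tokens == 64
# handle_post_chat_completions then forwards it as:
#   node.process_prompt(shard, prompt, request_id=request_id, generation_options=opts)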
17 changes: 13 additions & 4 deletions exo/networking/grpc/grpc_peer_handle.py
@@ -3,6 +3,8 @@
import asyncio
from typing import Optional, Tuple, List

from exo.inference.generation_options import GenerationOptions

from . import node_service_pb2
from . import node_service_pb2_grpc

@@ -97,7 +99,7 @@ async def health_check(self) -> bool:
traceback.print_exc()
return False

async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None, generation_options: Optional[GenerationOptions] = None) -> Optional[np.array]:
request = node_service_pb2.PromptRequest(
prompt=prompt,
shard=node_service_pb2.Shard(
@@ -107,11 +109,12 @@ async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional
n_layers=shard.n_layers,
),
request_id=request_id,
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state),
generation_options=None if generation_options is None else self.serialize_generation_options(generation_options)
)
await self.stub.SendPrompt(request)

async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None, generation_options: Optional[GenerationOptions] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
@@ -121,7 +124,8 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O
),
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
request_id=request_id,
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state),
generation_options=None if generation_options is None else self.serialize_generation_options(generation_options)
)
response = await self.stub.SendTensor(request)

@@ -217,3 +221,8 @@ def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.I
if other_data:
proto_inference_state.other_data_json = json.dumps(other_data)
return proto_inference_state

def serialize_generation_options(self, generation_options: GenerationOptions) -> node_service_pb2.GenerationOptions:
return node_service_pb2.GenerationOptions(
max_completion_tokens=generation_options.max_completion_tokens
)
13 changes: 11 additions & 2 deletions exo/networking/grpc/grpc_server.py
@@ -5,6 +5,8 @@

import platform

from exo.inference.generation_options import GenerationOptions

from . import node_service_pb2
from . import node_service_pb2_grpc
from exo import DEBUG
@@ -67,7 +69,8 @@ async def SendPrompt(self, request, context):
prompt = request.prompt
request_id = request.request_id
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
generation_options = None if request.generation_options is None else self.deserialize_generation_options(request.generation_options)
result = await self.node.process_prompt(shard, prompt, request_id, inference_state, generation_options)
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
@@ -83,8 +86,9 @@ async def SendTensor(self, request, context):
request_id = request.request_id

inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
generation_options = None if request.generation_options is None else self.deserialize_generation_options(request.generation_options)

result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
result = await self.node.process_tensor(shard, tensor, request_id, inference_state, generation_options)
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
@@ -169,3 +173,8 @@ def deserialize_inference_state(self, inference_state_proto: node_service_pb2.In
inference_state.update(other_data)

return inference_state

def deserialize_generation_options(self, generation_options_proto: node_service_pb2.GenerationOptions) -> GenerationOptions:
return GenerationOptions(
max_completion_tokens=generation_options_proto.max_completion_tokens if generation_options_proto.HasField("max_completion_tokens") else None
)
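
The round trip between serialize_generation_options and deserialize_generation_options relies on proto3 optional-field presence. A sketch of the semantics, using standard protobuf-python behaviour rather than anything introduced by this diff:

from exo.networking.grpc import node_service_pb2

# Passing None in the constructor leaves the optional field unset...
unset = node_service_pb2.GenerationOptions(max_completion_tokens=None)
assert not unset.HasField("max_completion_tokens")  # deserializes back to None

# ...while a concrete value round-trips intact.
capped = node_service_pb2.GenerationOptions(max_completion_tokens=128)
assert capped.HasField("max_completion_tokens") and capped.max_completion_tokens == 128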
6 changes: 6 additions & 0 deletions exo/networking/grpc/node_service.proto
@@ -19,18 +19,24 @@ message Shard {
int32 n_layers = 4;
}

message GenerationOptions {
optional int32 max_completion_tokens = 1;
}

message PromptRequest {
Shard shard = 1;
string prompt = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
optional GenerationOptions generation_options = 5;
}

message TensorRequest {
Shard shard = 1;
Tensor tensor = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
optional GenerationOptions generation_options = 5;
}

message ExampleRequest {