diff --git a/.github/bench.py b/.github/bench.py index 9bc52e894..554468d90 100644 --- a/.github/bench.py +++ b/.github/bench.py @@ -265,6 +265,9 @@ async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dic if not line.startswith('data: '): continue + if line == 'data: [DONE]': + break + data = json.loads(line[6:]) # Skip 'data: ' prefix if content := data.get('choices', [{}])[0].get('delta', {}).get('content'): print(f"Received content: {content}", flush=True) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 1020fdbc3..5618baa8e 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -14,6 +14,7 @@ from exo.helpers import PrefixDict, shutdown, get_exo_images_dir from exo.inference.tokenizers import resolve_tokenizer from exo.orchestration import Node +from exo.inference.generation_options import GenerationOptions from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name from typing import Callable, Optional from PIL import Image @@ -47,15 +48,22 @@ def to_dict(self): class ChatCompletionRequest: - def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None): + def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None, + max_completion_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None): self.model = model self.messages = messages self.temperature = temperature self.tools = tools + self.max_completion_tokens = max_completion_tokens + self.stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else None def to_dict(self): - return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools} + return {"model": self.model, "messages": [message.to_dict() for message in self.messages], + "temperature": self.temperature, "tools": self.tools, "max_completion_tokens": self.max_completion_tokens, + "stop": self.stop} + def to_generation_options(self) -> GenerationOptions: + return GenerationOptions(max_completion_tokens=self.max_completion_tokens, stop=self.stop) def generate_completion( chat_request: ChatCompletionRequest, @@ -67,6 +75,7 @@ def generate_completion( finish_reason: Union[Literal["length", "stop"], None], object_type: Literal["chat.completion", "text_completion"], ) -> dict: + decoded_tokens = tokenizer.decode(tokens) completion = { "id": f"chatcmpl-{request_id}", "object": object_type, @@ -75,7 +84,6 @@ def generate_completion( "system_fingerprint": f"exo_{VERSION}", "choices": [{ "index": 0, - "message": {"role": "assistant", "content": tokenizer.decode(tokens)}, "logprobs": None, "finish_reason": finish_reason, }], @@ -90,10 +98,12 @@ def generate_completion( choice = completion["choices"][0] if object_type.startswith("chat.completion"): - key_name = "delta" if stream else "message" - choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)} + if stream: + choice["delta"] = {"role": "assistant", "content": decoded_tokens} if len(decoded_tokens) > 0 else {} + else: + choice["message"] = {"role": "assistant", "content": decoded_tokens} elif object_type == "text_completion": - choice["text"] = tokenizer.decode(tokens) + choice["text"] = decoded_tokens else: ValueError(f"Unsupported response type: {object_type}") @@ -137,7 +147,7 @@ def remap_messages(messages: List[Message]) -> List[Message]: def build_prompt(tokenizer, _messages: 
List[Message], tools: Optional[List[Dict]] = None): messages = remap_messages(_messages) chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True} - if tools: + if tools: chat_template_args["tools"] = tools try: @@ -147,7 +157,7 @@ def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict] except UnicodeEncodeError: # Handle Unicode encoding by ensuring everything is UTF-8 chat_template_args["conversation"] = [ - {k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v + {k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v for k, v in m.to_dict().items()} for m in messages ] @@ -168,6 +178,10 @@ def parse_chat_request(data: dict, default_model: str): [parse_message(msg) for msg in data["messages"]], data.get("temperature", 0.0), data.get("tools", None), + # The max_tokens field is deprecated, but some clients may still use it, fall back to that value if + # max_completion_tokens is not provided. + data.get("max_completion_tokens", data.get("max_tokens", None)), + data.get("stop", None), ) @@ -201,7 +215,7 @@ def __init__( # Get the callback system and register our handler self.token_callback = node.on_token.register("chatgpt-api-token-handler") - self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished))) + self.token_callback.on_next(lambda _request_id, tokens, is_finished, finish_reason: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished, finish_reason))) self.system_prompt = system_prompt cors = aiohttp_cors.setup(self.app) @@ -234,7 +248,7 @@ def __init__( self.static_dir = Path(__file__).parent.parent/"tinychat" self.app.router.add_get("/", self.handle_root) self.app.router.add_static("/", self.static_dir, name="static") - + # Always add images route, regardless of compilation status self.images_dir = get_exo_images_dir() self.images_dir.mkdir(parents=True, exist_ok=True) @@ -357,7 +371,12 @@ async def handle_post_chat_completions(self, request): if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}") try: - await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout) + await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt( + shard, + prompt, + request_id=request_id, + generation_options=chat_request.to_generation_options() + ))), timeout=self.response_timeout) if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. 
timeout={self.response_timeout}s") @@ -376,36 +395,59 @@ async def handle_post_chat_completions(self, request): # Stream tokens while waiting for inference to complete while True: if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}") - tokens, is_finished = await asyncio.wait_for( + tokens, is_finished, finish_reason = await asyncio.wait_for( self.token_queues[request_id].get(), timeout=self.response_timeout ) - if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}") + if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=} {finish_reason=}") eos_token_id = None if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") - finish_reason = None - if is_finished: finish_reason = "stop" if tokens[-1] == eos_token_id else "length" - if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=} {finish_reason=}") - - completion = generate_completion( - chat_request, - tokenizer, - prompt, - request_id, - tokens, - stream, - finish_reason, - "chat.completion", - ) + if len(tokens) == 0 and not is_finished: + continue + + if len(tokens) > 0: + if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=}") + if is_finished: + if tokens[-1] == eos_token_id: + # We do not return the EOS token in the response + tokens.pop(-1) + + if DEBUG >= 2: print(f"{finish_reason=}") + + if len(tokens) > 0: + completion = generate_completion( + chat_request, + tokenizer, + prompt, + request_id, + tokens, + stream, + None, + "chat.completion", + ) - await response.write(f"data: {json.dumps(completion)}\n\n".encode()) + await response.write(f"data: {json.dumps(completion)}\n\n".encode()) if is_finished: + completion = generate_completion( + chat_request, + tokenizer, + prompt, + request_id, + [], + stream, + finish_reason, + "chat.completion", + ) + + await response.write(f"data: {json.dumps(completion)}\n\n".encode()) break + # Send the DONE event when the stream is finished + await response.write(b"data: [DONE]\n\n") await response.write_eof() return response @@ -414,7 +456,7 @@ async def handle_post_chat_completions(self, request): return web.json_response({"detail": "Response generation timed out"}, status=408) except Exception as e: - if DEBUG >= 2: + if DEBUG >= 2: print(f"[ChatGPTAPI] Error processing prompt: {e}") traceback.print_exc() return web.json_response( @@ -430,17 +472,19 @@ async def handle_post_chat_completions(self, request): else: tokens = [] while True: - _tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout) + _tokens, is_finished, finish_reason = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout) tokens.extend(_tokens) if is_finished: break - finish_reason = "length" + eos_token_id = None if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") - if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}") - if tokens[-1] == eos_token_id: - finish_reason = "stop" + if len(tokens) > 0: + if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}") + if tokens[-1] == eos_token_id: + # We do not return the EOS token in the 
response + tokens.pop(-1) return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion")) except asyncio.TimeoutError: @@ -501,22 +545,22 @@ async def stream_image(_request_id: str, result, is_finished: bool): image_filename = f"{_request_id}.png" image_path = self.images_dir/image_filename im.save(image_path) - + # Get URL for the saved image try: image_url = request.app.router['static_images'].url_for(filename=image_filename) base_url = f"{request.scheme}://{request.host}" full_image_url = base_url + str(image_url) - + await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n') except KeyError as e: if DEBUG >= 2: print(f"Error getting image URL: {e}") # Fallback to direct file path if URL generation fails await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n') - + if is_finished: await response.write_eof() - + except Exception as e: if DEBUG >= 2: print(f"Error processing image: {e}") if DEBUG >= 2: traceback.print_exc() @@ -620,8 +664,8 @@ async def handle_get_topology(self, request): if DEBUG >= 2: traceback.print_exc() return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500) - async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool): - await self.token_queues[request_id].put((tokens, is_finished)) + async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool, finish_reason: Optional[str] = None): + await self.token_queues[request_id].put((tokens, is_finished, finish_reason)) async def run(self, host: str = "0.0.0.0", port: int = 52415): runner = web.AppRunner(self.app) diff --git a/exo/api/test_chatgptapi.py b/exo/api/test_chatgptapi.py new file mode 100644 index 000000000..a53d5fcb1 --- /dev/null +++ b/exo/api/test_chatgptapi.py @@ -0,0 +1,213 @@ +import pytest +from openai import OpenAI, AsyncOpenAI +import aiohttp +import json + +# Test configuration +API_BASE_URL = "http://localhost:52415/v1/" +TEST_MODEL = "llama-3.2-1b" + + +@pytest.fixture +def client(): + return OpenAI( + base_url=API_BASE_URL, + api_key="sk-1111" + ) + + +@pytest.fixture +def async_client(): + return AsyncOpenAI( + base_url=API_BASE_URL, + api_key="sk-1111" + ) + + +@pytest.mark.asyncio +async def test_basic_chat_completion(client): + """Test basic non-streaming chat completion""" + response = client.chat.completions.create( + model=TEST_MODEL, + messages=[{"role": "user", "content": "Say 'Hello world'"}], + temperature=0.0 + ) + + assert response.id.startswith("chatcmpl-") + assert response.object == "chat.completion" + assert response.model == TEST_MODEL + assert len(response.choices) == 1 + assert response.choices[0].finish_reason == "stop" + assert "Hello" in response.choices[0].message.content + + +@pytest.mark.asyncio +async def test_streaming_chat_completion(async_client): + """Test streaming chat completion""" + stream = await async_client.chat.completions.create( + model=TEST_MODEL, + messages=[{"role": "user", "content": "Count to 5 separated by commas"}], + temperature=0.0, + stream=True + ) + + responses = [] + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.content: + responses.append(delta.content) + + full_response = "".join(responses) + assert full_response.count(",") == 4 # "1, 2, 3, 4, 5" + assert chunk.choices[0].finish_reason == "stop" + + +@pytest.mark.asyncio +async def 
test_max_completion_tokens(client): + """Test max_completion_tokens and max_tokens fallback""" + # Test max_completion_tokens + response1 = client.chat.completions.create( + model=TEST_MODEL, + messages=[{"role": "user", "content": "Repeat 'foo bar' 10 times"}], + temperature=0.0, + max_completion_tokens=5 + ) + + # Test max_tokens fallback + response2 = client.chat.completions.create( + model=TEST_MODEL, + messages=[{"role": "user", "content": "Repeat 'foo bar' 10 times"}], + temperature=0.0, + max_tokens=5 # Deprecated parameter fallback + ) + + for response in [response1, response2]: + assert response.choices[0].finish_reason == "length" + assert response.usage.completion_tokens <= 5 + + +@pytest.mark.asyncio +async def test_stop_sequences(client): + """Test stop sequence handling""" + response = client.chat.completions.create( + model=TEST_MODEL, + messages=[{"role": "user", "content": "Complete this sequence directly: A B C"}], + temperature=0.0, + stop=["D"], + max_completion_tokens=20 + ) + + content = response.choices[0].message.content + assert "D" not in content + assert response.choices[0].finish_reason == "stop" + + +@pytest.mark.asyncio +async def test_raw_http_request(): + """Test API using raw HTTP request for basic completion""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"{API_BASE_URL}chat/completions", + json={ + "model": TEST_MODEL, + "messages": [{"role": "user", "content": "2+2="}], + "temperature": 0.0 + } + ) as resp: + data = await resp.json() + assert "4" in data["choices"][0]["message"]["content"] + + +@pytest.mark.asyncio +async def test_raw_http_request_streaming(): + """Test API using raw HTTP request for streaming""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"{API_BASE_URL}chat/completions", + json={ + "model": TEST_MODEL, + "messages": [{"role": "user", + "content": "Count to 3 from 1, separate the numbers with commas and a space only."}], + "temperature": 0.0, + "stream": True + } + ) as resp: + data_lines = [] + async for line in resp.content: + if line.startswith(b"data: "): + data_lines.append(line[6:]) + + # Verify last line is DONE + assert data_lines[-1].strip() == b"[DONE]" + + # Process all but last line which is DONE + chunks = [] + for line in data_lines[:-1]: + chunk = json.loads(line) + if chunk["choices"][0]["delta"].get("content"): + chunks.append(chunk["choices"][0]["delta"]["content"]) + + result = "".join(chunks) + assert "1, 2, 3" in result + + +@pytest.mark.asyncio +async def test_raw_http_no_eot_id(): + """Test that responses don't include the <|eot_id|> special token. 
Note this token is model dependent.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"{API_BASE_URL}chat/completions", + json={ + "model": TEST_MODEL, + "messages": [{"role": "user", "content": "Say exactly this: <|eot_id|>"}], + "temperature": 0.0 + } + ) as resp: + data = await resp.json() + content = data["choices"][0]["message"]["content"] + # Should return the literal text without the special token + assert "<|eot_id|>" not in content + # Verify the response is properly sanitized + assert "eot_id" not in content.lower() + + +@pytest.mark.asyncio +async def test_stop_sequence_first_token(client): + """Test stop sequence when it's the first generated token""" + response = client.chat.completions.create( + model=TEST_MODEL, + messages=[{"role": "user", "content": "Please repeat the word supercalifragilisticexpialidocious"}], + temperature=0.0, + stop=["T", "sup"], + max_completion_tokens=20 + ) + + content = response.choices[0].message.content + assert content == "" + assert response.choices[0].finish_reason == "stop" + + +@pytest.mark.asyncio +async def test_stop_sequence_first_token_streaming(async_client): + """Test stop sequence handling when first generated token matches stop sequence (streaming)""" + stream = await async_client.chat.completions.create( + model=TEST_MODEL, + messages=[{"role": "user", "content": "Please repeat the word supercalifragilisticexpialidocious"}], + temperature=0.0, + stop=["T", "sup"], + max_completion_tokens=20, + stream=True + ) + + content = [] + finish_reason = None + async for chunk in stream: + if chunk.choices[0].delta.content: + content.append(chunk.choices[0].delta.content) + if chunk.choices[0].finish_reason: + finish_reason = chunk.choices[0].finish_reason + + # Should either get empty content with stop reason, + # or content that doesn't contain the stop sequence + assert "".join(content) == "" + assert finish_reason == "stop" diff --git a/exo/inference/generation_options.py b/exo/inference/generation_options.py new file mode 100644 index 000000000..727bbe639 --- /dev/null +++ b/exo/inference/generation_options.py @@ -0,0 +1,12 @@ +from typing import Optional, List + + +class GenerationOptions: + max_completion_tokens: Optional[int] = None + + # Textual stop sequences that will halt generation when encountered + stop: Optional[List[str]] = None + + def __init__(self, max_completion_tokens: Optional[int] = None, stop: Optional[List[str]] = None): + self.max_completion_tokens = max_completion_tokens + self.stop = stop diff --git a/exo/main.py b/exo/main.py index db91251c8..786252fa6 100644 --- a/exo/main.py +++ b/exo/main.py @@ -173,7 +173,7 @@ def configure_uvloop(): system_prompt=args.system_prompt ) buffered_token_output = {} -def update_topology_viz(req_id, tokens, __): +def update_topology_viz(req_id, tokens, __, ___): if not topology_viz: return if not node.inference_engine.shard: return if node.inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py index 997d43ef1..3b8cdb81c 100644 --- a/exo/networking/grpc/grpc_peer_handle.py +++ b/exo/networking/grpc/grpc_peer_handle.py @@ -3,6 +3,8 @@ import asyncio from typing import Optional, Tuple, List +from exo.inference.generation_options import GenerationOptions + from . import node_service_pb2 from . 
import node_service_pb2_grpc @@ -97,7 +99,7 @@ async def health_check(self) -> bool: traceback.print_exc() return False - async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]: + async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None, generation_options: Optional[GenerationOptions] = None) -> Optional[np.array]: await self._ensure_connected() request = node_service_pb2.PromptRequest( prompt=prompt, @@ -108,11 +110,12 @@ async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional n_layers=shard.n_layers, ), request_id=request_id, - inference_state=None if inference_state is None else self.serialize_inference_state(inference_state) + inference_state=None if inference_state is None else self.serialize_inference_state(inference_state), + generation_options=None if generation_options is None else self.serialize_generation_options(generation_options) ) await self.stub.SendPrompt(request) - async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]: + async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None, generation_options: Optional[GenerationOptions] = None) -> Optional[np.array]: await self._ensure_connected() request = node_service_pb2.TensorRequest( shard=node_service_pb2.Shard( @@ -123,7 +126,8 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O ), tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)), request_id=request_id, - inference_state=None if inference_state is None else self.serialize_inference_state(inference_state) + inference_state=None if inference_state is None else self.serialize_inference_state(inference_state), + generation_options=None if generation_options is None else self.serialize_generation_options(generation_options) ) response = await self.stub.SendTensor(request) @@ -189,13 +193,13 @@ async def collect_topology(self, visited: set[str], max_depth: int) -> Topology: topology.add_edge(node_id, conn.to_id, conn.description) return topology - async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None: + async def send_result(self, request_id: str, result: List[int], is_finished: bool, finish_reason: Optional[str] = None) -> None: await self._ensure_connected() tensor = None if isinstance(result, np.ndarray): tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype)) result = [] - request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished) + request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished, finish_reason=finish_reason) await self.stub.SendResult(request) async def send_opaque_status(self, request_id: str, status: str) -> None: @@ -224,3 +228,9 @@ def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.I if other_data: proto_inference_state.other_data_json = json.dumps(other_data) return proto_inference_state + + def serialize_generation_options(self, generation_options: GenerationOptions) -> node_service_pb2.GenerationOptions: + return node_service_pb2.GenerationOptions( + 
max_completion_tokens=generation_options.max_completion_tokens, + stop=generation_options.stop + ) diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py index cbc33f949..5b238d6f3 100644 --- a/exo/networking/grpc/grpc_server.py +++ b/exo/networking/grpc/grpc_server.py @@ -5,6 +5,8 @@ import platform +from exo.inference.generation_options import GenerationOptions + from . import node_service_pb2 from . import node_service_pb2_grpc from exo import DEBUG @@ -69,7 +71,8 @@ async def SendPrompt(self, request, context): prompt = request.prompt request_id = request.request_id inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state) - result = await self.node.process_prompt(shard, prompt, request_id, inference_state) + generation_options = None if request.generation_options is None else self.deserialize_generation_options(request.generation_options) + result = await self.node.process_prompt(shard, prompt, request_id, inference_state, generation_options) if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}") tensor_data = result.tobytes() if result is not None else None return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor() @@ -85,8 +88,9 @@ async def SendTensor(self, request, context): request_id = request.request_id inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state) + generation_options = None if request.generation_options is None else self.deserialize_generation_options(request.generation_options) - result = await self.node.process_tensor(shard, tensor, request_id, inference_state) + result = await self.node.process_tensor(shard, tensor, request_id, inference_state, generation_options) if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}") tensor_data = result.tobytes() if result is not None else None return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor() @@ -138,12 +142,13 @@ async def SendResult(self, request, context): request_id = request.request_id result = request.result is_finished = request.is_finished + finish_reason = request.finish_reason img = request.tensor - if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}") + if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=} {finish_reason=}") result = list(result) if len(img.tensor_data) > 0: result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape) - self.node.on_token.trigger_all(request_id, result, is_finished) + self.node.on_token.trigger_all(request_id, result, is_finished, finish_reason) return node_service_pb2.Empty() async def SendOpaqueStatus(self, request, context): @@ -171,3 +176,9 @@ def deserialize_inference_state(self, inference_state_proto: node_service_pb2.In inference_state.update(other_data) return inference_state + + def deserialize_generation_options(self, generation_options_proto: node_service_pb2.GenerationOptions) -> GenerationOptions: + return GenerationOptions( + max_completion_tokens=generation_options_proto.max_completion_tokens if generation_options_proto.HasField("max_completion_tokens") else None, + stop=generation_options_proto.stop, + ) diff --git 
a/exo/networking/grpc/node_service.proto b/exo/networking/grpc/node_service.proto index 882a5247f..bd836de57 100644 --- a/exo/networking/grpc/node_service.proto +++ b/exo/networking/grpc/node_service.proto @@ -19,11 +19,17 @@ message Shard { int32 n_layers = 4; } +message GenerationOptions { + optional int32 max_completion_tokens = 1; + repeated string stop = 2; +} + message PromptRequest { Shard shard = 1; string prompt = 2; optional string request_id = 3; optional InferenceState inference_state = 4; + optional GenerationOptions generation_options = 5; } message TensorRequest { @@ -31,6 +37,7 @@ message TensorRequest { Tensor tensor = 2; optional string request_id = 3; optional InferenceState inference_state = 4; + optional GenerationOptions generation_options = 5; } message ExampleRequest { @@ -100,6 +107,7 @@ message SendResultRequest { repeated int32 result = 2; optional Tensor tensor = 3; bool is_finished = 4; + optional string finish_reason = 5; } message SendOpaqueStatusRequest { diff --git a/exo/networking/grpc/node_service_pb2.py b/exo/networking/grpc/node_service_pb2.py index 9a83380dd..16637366e 100644 --- a/exo/networking/grpc/node_service_pb2.py +++ b/exo/networking/grpc/node_service_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xbb\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 
\x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x82\x01\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x97\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"_\n\x11GenerationOptions\x12\"\n\x15max_completion_tokens\x18\x01 \x01(\x05H\x00\x88\x01\x01\x12\x0c\n\x04stop\x18\x02 \x03(\tB\x18\n\x16_max_completion_tokens\"\x94\x02\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x12@\n\x12generation_options\x18\x05 \x01(\x0b\x32\x1f.node_service.GenerationOptionsH\x02\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_stateB\x15\n\x13_generation_options\"\xaa\x02\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 
\x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x12@\n\x12generation_options\x18\x05 \x01(\x0b\x32\x1f.node_service.GenerationOptionsH\x02\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_stateB\x15\n\x13_generation_options\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\xb0\x01\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x12\x1a\n\rfinish_reason\x18\x05 \x01(\tH\x01\x88\x01\x01\x42\t\n\x07_tensorB\x10\n\x0e_finish_reason\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 
\x01(\x08\"\x07\n\x05\x45mpty2\x97\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -41,50 +41,52 @@ _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001' _globals['_SHARD']._serialized_start=36 _globals['_SHARD']._serialized_end=119 - _globals['_PROMPTREQUEST']._serialized_start=122 - _globals['_PROMPTREQUEST']._serialized_end=309 - _globals['_TENSORREQUEST']._serialized_start=312 - _globals['_TENSORREQUEST']._serialized_end=521 - _globals['_EXAMPLEREQUEST']._serialized_start=524 - _globals['_EXAMPLEREQUEST']._serialized_end=746 - _globals['_LOSS']._serialized_start=748 - _globals['_LOSS']._serialized_end=820 - _globals['_TENSOR']._serialized_start=822 - _globals['_TENSOR']._serialized_end=881 - _globals['_TENSORLIST']._serialized_start=883 - _globals['_TENSORLIST']._serialized_end=934 - _globals['_INFERENCESTATE']._serialized_start=937 - _globals['_INFERENCESTATE']._serialized_end=1275 - _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_start=1123 - _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_end=1194 - _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_start=1196 - _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_end=1275 - _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=1277 - _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=1337 - _globals['_TOPOLOGY']._serialized_start=1340 - _globals['_TOPOLOGY']._serialized_end=1620 - _globals['_TOPOLOGY_NODESENTRY']._serialized_start=1461 - _globals['_TOPOLOGY_NODESENTRY']._serialized_end=1539 - _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1541 - _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1620 - _globals['_PEERCONNECTION']._serialized_start=1622 - _globals['_PEERCONNECTION']._serialized_end=1695 - _globals['_PEERCONNECTIONS']._serialized_start=1697 - _globals['_PEERCONNECTIONS']._serialized_end=1765 - _globals['_DEVICEFLOPS']._serialized_start=1767 - _globals['_DEVICEFLOPS']._serialized_end=1822 - _globals['_DEVICECAPABILITIES']._serialized_start=1824 - _globals['_DEVICECAPABILITIES']._serialized_end=1931 - _globals['_SENDRESULTREQUEST']._serialized_start=1934 - _globals['_SENDRESULTREQUEST']._serialized_end=2064 - _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=2066 - _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=2127 - _globals['_HEALTHCHECKREQUEST']._serialized_start=2129 - _globals['_HEALTHCHECKREQUEST']._serialized_end=2149 - _globals['_HEALTHCHECKRESPONSE']._serialized_start=2151 - _globals['_HEALTHCHECKRESPONSE']._serialized_end=2192 - _globals['_EMPTY']._serialized_start=2194 - _globals['_EMPTY']._serialized_end=2201 - _globals['_NODESERVICE']._serialized_start=2204 - _globals['_NODESERVICE']._serialized_end=2739 + _globals['_GENERATIONOPTIONS']._serialized_start=121 + 
_globals['_GENERATIONOPTIONS']._serialized_end=216 + _globals['_PROMPTREQUEST']._serialized_start=219 + _globals['_PROMPTREQUEST']._serialized_end=495 + _globals['_TENSORREQUEST']._serialized_start=498 + _globals['_TENSORREQUEST']._serialized_end=796 + _globals['_EXAMPLEREQUEST']._serialized_start=799 + _globals['_EXAMPLEREQUEST']._serialized_end=1021 + _globals['_LOSS']._serialized_start=1023 + _globals['_LOSS']._serialized_end=1095 + _globals['_TENSOR']._serialized_start=1097 + _globals['_TENSOR']._serialized_end=1156 + _globals['_TENSORLIST']._serialized_start=1158 + _globals['_TENSORLIST']._serialized_end=1209 + _globals['_INFERENCESTATE']._serialized_start=1212 + _globals['_INFERENCESTATE']._serialized_end=1550 + _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_start=1398 + _globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_end=1469 + _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_start=1471 + _globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_end=1550 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=1552 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=1612 + _globals['_TOPOLOGY']._serialized_start=1615 + _globals['_TOPOLOGY']._serialized_end=1895 + _globals['_TOPOLOGY_NODESENTRY']._serialized_start=1736 + _globals['_TOPOLOGY_NODESENTRY']._serialized_end=1814 + _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1816 + _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1895 + _globals['_PEERCONNECTION']._serialized_start=1897 + _globals['_PEERCONNECTION']._serialized_end=1970 + _globals['_PEERCONNECTIONS']._serialized_start=1972 + _globals['_PEERCONNECTIONS']._serialized_end=2040 + _globals['_DEVICEFLOPS']._serialized_start=2042 + _globals['_DEVICEFLOPS']._serialized_end=2097 + _globals['_DEVICECAPABILITIES']._serialized_start=2099 + _globals['_DEVICECAPABILITIES']._serialized_end=2206 + _globals['_SENDRESULTREQUEST']._serialized_start=2209 + _globals['_SENDRESULTREQUEST']._serialized_end=2385 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=2387 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=2448 + _globals['_HEALTHCHECKREQUEST']._serialized_start=2450 + _globals['_HEALTHCHECKREQUEST']._serialized_end=2470 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=2472 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=2513 + _globals['_EMPTY']._serialized_start=2515 + _globals['_EMPTY']._serialized_end=2522 + _globals['_NODESERVICE']._serialized_start=2525 + _globals['_NODESERVICE']._serialized_end=3060 # @@protoc_insertion_point(module_scope) diff --git a/exo/networking/peer_handle.py b/exo/networking/peer_handle.py index d75318efc..3cf7b5a7b 100644 --- a/exo/networking/peer_handle.py +++ b/exo/networking/peer_handle.py @@ -2,6 +2,7 @@ from typing import Optional, Tuple, List import numpy as np from exo.inference.shard import Shard +from exo.inference.generation_options import GenerationOptions from exo.topology.device_capabilities import DeviceCapabilities from exo.topology.topology import Topology @@ -40,11 +41,11 @@ async def health_check(self) -> bool: pass @abstractmethod - async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]: + async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, generation_options: Optional[GenerationOptions] = None) -> Optional[np.array]: pass @abstractmethod - async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]: + async def 
send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, generation_options: Optional[GenerationOptions] = None) -> Optional[np.array]: pass @abstractmethod diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py index 1281aa8ae..628073225 100644 --- a/exo/orchestration/node.py +++ b/exo/orchestration/node.py @@ -10,6 +10,7 @@ from exo.topology.topology import Topology from exo.topology.device_capabilities import device_capabilities, UNKNOWN_DEVICE_CAPABILITIES from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards +from exo.inference.generation_options import GenerationOptions from exo import DEBUG from exo.helpers import AsyncCallbackSystem from exo.viz.topology_viz import TopologyViz @@ -17,6 +18,109 @@ from exo.inference.inference_engine import get_inference_engine, InferenceEngine from exo.download.shard_download import ShardDownloader +class BufferedOutput: + stop_sequences: List[str] + max_tokens: int + eos_token_id: int + buffer_char_size: int + + _token_count: int = 0 + buffer: List[Tuple[int, str]] + + is_finished: bool = False + finish_reason: Optional[str] = None + + def __init__(self, max_tokens: int, eos_token_id: int, stop_sequences: List[str], tokenizer): + self.buffer = [] + self.buffer_char_size = max(len(stop_sequence) for stop_sequence in stop_sequences) if len(stop_sequences) > 0 else 0 + self.max_tokens = max_tokens + self.eos_token_id = eos_token_id + self.stop_sequences = stop_sequences + self.tokenizer = tokenizer + + def append(self, token: int): + self.buffer.append((token, self.tokenizer.decode([token]))) + self._token_count += 1 + + if token == self.eos_token_id: + self.is_finished = True + self.finish_reason = "stop" + elif self._token_count >= self.max_tokens: + self.is_finished = True + self.finish_reason = "length" + elif len(self.stop_sequences) > 0: + self.attempt_to_match_stop_sequences() + + def assembled_text(self) -> str: + return "".join([text for _, text in self.buffer]) + + def attempt_to_match_stop_sequences(self): + assembled_text = self.assembled_text() + if DEBUG >= 2: print(f"Attempting to match stop sequences against: {assembled_text=}") + + for stop_sequence in self.stop_sequences: + if len(assembled_text) < len(stop_sequence): + continue + + if DEBUG >= 2: print(f"Checking if {assembled_text=} matches {stop_sequence=}") + + if stop_sequence in assembled_text: + if DEBUG >= 2: print(f"Match found: {assembled_text=} matches {stop_sequence=}") + + # Find character index where stop sequence starts + char_idx = assembled_text.index(stop_sequence) + + # Find which token contains this character index and where in that token the sequence starts + current_char_pos = 0 + tokens_to_keep = 0 + for _, text in self.buffer: + next_char_pos = current_char_pos + len(text) + if current_char_pos <= char_idx < next_char_pos: + # Found the token containing the stop sequence + # Get the text before the stop sequence + token_offset = char_idx - current_char_pos + truncated_text = text[:token_offset] + + # This is a little bit of a hack as the SendResults GRPC call expects tokens so to return the truncated text + # we need to retokenize it. This is not ideal as it means we are not returning the exact tokens that were + # generated by the model. However, it is the simplest way to handle this case. 
+ + # Retokenize the truncated text + new_tokens = self.tokenizer.encode(truncated_text, add_special_tokens=False) + + # Replace the final token with the retokenized truncated text + self.buffer = self.buffer[:tokens_to_keep] + for token in new_tokens: + self.buffer.append((token, self.tokenizer.decode([token]))) + break + + current_char_pos = next_char_pos + tokens_to_keep += 1 + else: + # If we didn't find the token, just keep everything up to char_idx + self.buffer = self.buffer[:tokens_to_keep] + + self.is_finished = True + self.finish_reason = "stop" + break + + def token_count(self) -> int: + return self._token_count + + def next_tokens(self) -> List[int]: + if self.is_finished: + # Return all remaining tokens if finished + tokens = [token for token, _ in self.buffer] + self.buffer = [] + return tokens + elif len(self.assembled_text()) >= self.buffer_char_size: + token, _ = self.buffer.pop(0) + return [token] + + # Not enough tokens yet + return [] + + class Node: def __init__( self, @@ -39,16 +143,16 @@ def __init__( self.peers: List[PeerHandle] = {} self.topology: Topology = Topology() self.device_capabilities = UNKNOWN_DEVICE_CAPABILITIES - self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {} + self.buffered_token_output: Dict[str, BufferedOutput] = {} self.buffered_logits: Dict[str, List[np.ndarray]] = {} self.buffered_inputs: Dict[str, List[np.ndarray]] = {} self.buffered_partials: Dict[str, List[np.ndarray]] = {} self.checkpoints: Dict[str, Dict[str, int]] = {} - + self.max_generate_tokens = max_generate_tokens self.topology_viz = topology_viz self.default_sample_temperature = default_sample_temperature - self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]() + self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool, Optional[str]]]() self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]() self._on_opaque_status.register("node_status").on_next(self.on_node_status) self.node_download_progress: Dict[str, RepoProgressEvent] = {} @@ -110,49 +214,95 @@ async def broadcast_supported_engines(self, supported_engines_names: List[str]): def get_topology_inference_engines(self) -> List[List[str]]: return self.topology_inference_engines_pool - - token_count = 0 - first_token_time = 0 + async def process_inference_result( self, shard, result: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[dict] = None, + generation_options: Optional[GenerationOptions] = None, ): - if shard.model_id != 'stable-diffusion-2-1-base': - if request_id not in self.buffered_token_output: - self.buffered_token_output[request_id] = ([], False) - is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens - if shard.is_last_layer() and not is_finished: - token = await self.inference_engine.sample(result, temp=self.default_sample_temperature) - await self.inference_engine.ensure_shard(shard) - self.buffered_token_output[request_id][0].append(token.item()) - is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens - if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}") - forward = token.reshape(1, -1) - intermediate_result = [self.buffered_token_output[request_id][0][-1]] - else: - forward = result + await self.inference_engine.ensure_shard(shard) + + if shard.model_id == 
'stable-diffusion-2-1-base': + # Stable Diffusion specific processing. This will mutate inference_state. + forward, intermediate_result, is_finished = await self.handle_stable_diffusion_inference( + result, inference_state + ) + + # We don't do finish reason determination here for stable diffusion + finish_reason = None else: - await self.inference_engine.ensure_shard(shard) - is_finished = inference_state.get("is_finished", False) - intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result) - forward = result + # LLM specific processing + forward, intermediate_result, is_finished, finish_reason = await self.handle_llm_inference( + shard, result, request_id, generation_options + ) + + # Yield the intermediate result before continuing further generation (for LLMs this will be the next token in the + # output). if shard.is_last_layer(): - self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished) - asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished)) + if intermediate_result is None: + raise ValueError("Intermediate result is None") + self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished, finish_reason) + asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished, finish_reason)) + + # Common completion handling if is_finished: - if shard.model_id != 'stable-diffusion-2-1-base': - self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True) - self.outstanding_requests.pop(request_id) + self.buffered_token_output.pop(request_id, None) + self.outstanding_requests.pop(request_id, None) else: self.outstanding_requests[request_id] = "waiting" - asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state)) + asyncio.create_task(self.forward_tensor( + shard, forward, request_id, self.get_partition_index(offset=1), + inference_state, generation_options + )) + + if shard.model_id == 'stable-diffusion-2-1-base': + return intermediate_result + else: + return None + + async def handle_llm_inference(self, shard: Shard, result, request_id: str, generation_options: GenerationOptions): + """Handle LLM-specific inference results processing""" + if not shard.is_last_layer(): + return result, None, False, None + + if request_id not in self.buffered_token_output: + max_tokens = self.max_generate_tokens + if generation_options and generation_options.max_completion_tokens: + max_tokens = min(max_tokens, generation_options.max_completion_tokens) + + stop_sequences = generation_options.stop or [] + self.buffered_token_output[request_id] = BufferedOutput( + eos_token_id=self.inference_engine.tokenizer.eos_token_id, + max_tokens=max_tokens, + stop_sequences=stop_sequences, + tokenizer=self.inference_engine.tokenizer, + ) + + buffered_output = self.buffered_token_output[request_id] + + token = await self.inference_engine.sample(result, temp=self.default_sample_temperature) + buffered_output.append(token.item()) + + if DEBUG >= 2: + print(f"[{request_id}] LLM result size: {result.size}, finished: {buffered_output.is_finished}, tokens: {buffered_output.token_count()}, finish_reason: {buffered_output.finish_reason}") + + return token.reshape(1, -1), buffered_output.next_tokens(), buffered_output.is_finished, buffered_output.finish_reason - return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result + async def 
handle_stable_diffusion_inference(self, result, inference_state): + """Handle Stable Diffusion-specific inference results processing""" + is_finished = inference_state.get("is_finished", False) + if inference_state['is_step_finished']: + inference_state['step'] += 1 + + if inference_state['step'] == inference_state['total_steps']: + is_finished = True + + return result, result, is_finished async def process_prompt( self, @@ -160,6 +310,7 @@ async def process_prompt( prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = {}, + generation_options: Optional[GenerationOptions] = None, ) -> Optional[np.ndarray]: shard = self.get_current_shard(base_shard) start_time = time.perf_counter_ns() @@ -178,7 +329,7 @@ async def process_prompt( ) ) start_time = time.perf_counter_ns() - resp = await self._process_prompt(base_shard, prompt, request_id, inference_state) + resp = await self._process_prompt(base_shard, prompt, request_id, inference_state, generation_options) end_time = time.perf_counter_ns() elapsed_time_ns = end_time - start_time asyncio.create_task( @@ -198,7 +349,9 @@ async def process_prompt( ) if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}") - async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]: + async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, + inference_state: Optional[dict] = None, + generation_options: Optional[GenerationOptions] = None) -> Optional[np.ndarray]: if request_id is None: request_id = str(uuid.uuid4()) shard = self.get_current_shard(base_shard) @@ -207,19 +360,19 @@ async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Opti if not shard.is_first_layer(): if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}") self.outstanding_requests[request_id] = "waiting" - resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state) + resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state, generation_options) return None else: self.outstanding_requests[request_id] = "processing" result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state) - ret = await self.process_inference_result(shard, result, request_id, inference_state) + ret = await self.process_inference_result(shard, result, request_id, inference_state, generation_options) return result async def enqueue_example( self, base_shard: Shard, example: np.ndarray, - target: np.ndarray, + target: np.ndarray, length: np.ndarray, request_id: Optional[str] = None, train: bool = False, @@ -232,7 +385,7 @@ async def enqueue_example( if request_id is None: request_id = str(uuid.uuid4()) self.outstanding_requests[request_id] = "waiting" - loss = await self.forward_example(shard, example, target, length, train, request_id, 0) + loss = await self.forward_example(shard, example, target, length, train, request_id, 0) return loss async def coordinate_save( @@ -263,7 +416,7 @@ async def process_example( self, base_shard: Shard, example: np.ndarray, - target: np.ndarray, + target: np.ndarray, length: np.ndarray, train: bool = False, request_id: Optional[str] = None, @@ -308,7 +461,7 @@ async def _process_example( self, base_shard: Shard, example: np.ndarray, - target: np.ndarray, + target: np.ndarray, length: np.ndarray, train: bool = False, 
@@ -308,7 +461,7 @@ async def _process_example(
     self,
     base_shard: Shard,
     example: np.ndarray,
-    target: np.ndarray, 
+    target: np.ndarray,
     length: np.ndarray,
     train: bool = False,
     request_id: Optional[str] = None,
@@ -351,17 +504,18 @@ async def _process_example(
       print(f"Error processing example for shard {shard}: {e}")
       traceback.print_exc()
       return None
-  
+
   async def process_tensor(
     self,
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
     inference_state: Optional[dict] = None,
+    generation_options: Optional[GenerationOptions] = None,
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     start_time = time.perf_counter_ns()
-    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+    resp = await self._process_tensor(shard, tensor, request_id, inference_state, generation_options)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
@@ -372,6 +526,7 @@ async def _process_tensor(
     tensor: np.ndarray,
     request_id: Optional[str] = None,
     inference_state: Optional[dict] = None,
+    generation_options: Optional[GenerationOptions] = None,
   ) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
@@ -380,13 +535,13 @@ async def _process_tensor(
     try:
       self.outstanding_requests[request_id] = "processing"
       result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
-      ret = await self.process_inference_result(shard, result, request_id, inference_state)
+      ret = await self.process_inference_result(shard, result, request_id, inference_state, generation_options)
       return ret
     except Exception as e:
       self.outstanding_requests.pop(request_id)
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
-  
+
   async def forward_example(
     self,
     base_shard: Shard,
@@ -415,20 +570,21 @@ async def forward_prompt(
     request_id: str,
     target_index: int,
     inference_state: Optional[dict] = None,
+    generation_options: Optional[GenerationOptions] = None,
   ) -> None:
     if DEBUG >= 1: print(f"target partition index: {target_index}")
     target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
     next_shard = self.get_current_shard(base_shard, target_index)
     if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
     if target_id == self.id:
-      await self.process_prompt(next_shard, prompt, request_id, inference_state)
+      await self.process_prompt(next_shard, prompt, request_id, inference_state, generation_options)
     else:
       target_peer = next((p for p in self.peers if p.id() == target_id), None)
       if not target_peer:
         raise ValueError(f"Peer for {target_index} not found")
       if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
-      await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
-  
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state, generation_options=generation_options)
+
   async def forward_tensor(
     self,
     base_shard: Shard,
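# Illustrative sketch, not part of the diff above: forward_prompt/forward_tensor resolve the next
# hop via get_partition_index(offset=1) and the partitioning strategy. The body of
# get_partition_index is not shown in this hunk; the ring-style arithmetic below is an assumption
# about how "current partition plus offset, wrapping around" would typically work.
from typing import List


def next_partition_index(node_ids: List[str], own_id: str, offset: int = 1) -> int:
  # Find our own slot in the ordered partition list, then step `offset` slots forward,
  # wrapping around so the last shard hands back to the first.
  current = node_ids.index(own_id)
  return (current + offset) % len(node_ids)


# Example: three nodes in a ring; node "b" forwards to the partition after its own,
# and node "c" wraps back to the first partition.
assert next_partition_index(["a", "b", "c"], "b", offset=1) == 2
assert next_partition_index(["a", "b", "c"], "c", offset=1) == 0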
@@ -436,19 +592,20 @@ async def forward_tensor(
     request_id: str,
     target_index: int,
     inference_state: Optional[dict] = None,
+    generation_options: Optional[GenerationOptions] = None,
   ) -> None:
     if DEBUG >= 1: print(f"target partition index: {target_index}")
     target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
     next_shard = self.get_current_shard(base_shard, target_index)
     if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
     if target_id == self.id:
-      await self.process_tensor(next_shard, tensor, request_id, inference_state)
+      await self.process_tensor(next_shard, tensor, request_id, inference_state, generation_options)
     else:
       target_peer = next((p for p in self.peers if p.id() == target_id), None)
       if not target_peer:
         raise ValueError(f"Peer for {target_index} not found")
       if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
-      await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state, generation_options=generation_options)

   def get_partition_index(self, offset: int = 0):
     if not self.partitioning_strategy:
@@ -574,22 +731,22 @@ async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topol
     return self.topology

   @property
-  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
+  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool, Optional[str]]]:
     return self._on_token

   @property
   def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
     return self._on_opaque_status

-  def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
-    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=}")
-    self.on_token.trigger_all(request_id, tokens, is_finished)
-
-  async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-    if DEBUG >= 2: print(f"Broadcasting result: {request_id=} {result=} {is_finished=}")
+  def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool, finish_reason: Optional[str] = None) -> None:
+    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=} {finish_reason=}")
+    self.on_token.trigger_all(request_id, tokens, is_finished, finish_reason)
+
+  async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool, finish_reason: Optional[str] = None) -> None:
+    if DEBUG >= 2: print(f"Broadcasting result: {request_id=} {result=} {is_finished=} {finish_reason=}")

     async def send_result_to_peer(peer):
       try:
-        await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
+        await asyncio.wait_for(peer.send_result(request_id, result, is_finished, finish_reason), timeout=15.0)
       except asyncio.TimeoutError:
         print(f"Timeout broadcasting result to {peer.id()}")
       except Exception as e:
@@ -617,12 +774,3 @@ async def send_status_to_peer(peer):
   @property
   def current_topology(self) -> Topology:
     return self.topology
-
-  def handle_stable_diffusion(self, inference_state, result):
-    if inference_state['is_step_finished']:
-      inference_state['step']+=1
-    progress = [inference_state['step'],inference_state['total_steps']]
-    intermediate_result = result
-    if progress[0] == progress[1]:
-      intermediate_result = result
-    return intermediate_result, inference_state
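# Illustrative sketch, not part of the diff above: the on_token callback tuple gained a fourth
# element, finish_reason, so a consumer that previously unpacked three values needs updating.
# This standalone handler only shows the expected shapes ("stop" / "length" / None); its name and
# the way it would be registered with the node are assumptions, not the PR's API.
from typing import List, Optional


def handle_tokens(request_id: str, tokens: List[int], is_finished: bool, finish_reason: Optional[str]) -> None:
  # finish_reason is None while generation is still running; on the final callback it is
  # expected to be "stop" (EOS or a stop sequence matched) or "length" (token budget reached).
  if is_finished:
    print(f"[{request_id}] finished with reason={finish_reason!r} after {len(tokens)} new token(s)")
  else:
    print(f"[{request_id}] streamed {len(tokens)} new token(s)")


# Example invocations mirroring what trigger_on_token_callbacks would emit:
handle_tokens("req-1", [42, 43], False, None)
handle_tokens("req-1", [44], True, "stop")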