Skip to content

Commit

Permalink
Adjust gRPC channel settings; ensure the peer connection is ready before sending any gRPC commands
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Feb 28, 2025
1 parent 36a6389 commit 4081305
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
33 changes: 20 additions & 13 deletions exo/networking/grpc/grpc_peer_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ def __init__(self, _id: str, address: str, desc: str, device_capabilities: Devic
self.channel = None
self.stub = None
self.channel_options = [
("grpc.max_metadata_size", 64 * 1024 * 1024),
("grpc.max_metadata_size", 32 * 1024 * 1024),
("grpc.max_receive_message_length", 256 * 1024 * 1024),
("grpc.max_send_message_length", 256 * 1024 * 1024),
("grpc.max_concurrent_streams", 100),
("grpc.http2.min_time_between_pings_ms", 10000),
("grpc.keepalive_time_ms", 20000),
("grpc.keepalive_timeout_ms", 10000),
("grpc.keepalive_time_ms", 10000),
("grpc.keepalive_timeout_ms", 5000),
("grpc.keepalive_permit_without_calls", 1),
("grpc.http2.max_pings_without_data", 0),
("grpc.http2.min_ping_interval_without_data_ms", 5000),
("grpc.tcp_nodelay", 1),
("grpc.optimization_target", "throughput"),
]
Expand All @@ -55,14 +56,13 @@ def device_capabilities(self) -> DeviceCapabilities:
return self._device_capabilities

async def connect(self):
if self.channel is None:
self.channel = grpc.aio.insecure_channel(
self.address,
options=self.channel_options,
compression=grpc.Compression.Gzip
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await self.channel.channel_ready()
self.channel = grpc.aio.insecure_channel(
self.address,
options=self.channel_options,
compression=grpc.Compression.Gzip
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await asyncio.wait_for(self.channel.channel_ready(), timeout=10.0)

async def is_connected(self) -> bool:
return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
Expand All @@ -74,7 +74,7 @@ async def disconnect(self):
self.stub = None

async def _ensure_connected(self):
if not await self.is_connected():
if not (await self.is_connected()):
try:
await asyncio.wait_for(self.connect(), timeout=10.0)
except asyncio.TimeoutError:
Expand All @@ -98,6 +98,7 @@ async def health_check(self) -> bool:
return False

async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
await self._ensure_connected()
request = node_service_pb2.PromptRequest(
prompt=prompt,
shard=node_service_pb2.Shard(
Expand All @@ -112,6 +113,7 @@ async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional
await self.stub.SendPrompt(request)

async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
await self._ensure_connected()
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
Expand All @@ -131,6 +133,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
await self._ensure_connected()
request = node_service_pb2.ExampleRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
Expand All @@ -153,6 +156,7 @@ async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarr
return loss

async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
await self._ensure_connected()
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
Expand All @@ -171,6 +175,7 @@ async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)

async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
await self._ensure_connected()
request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
response = await self.stub.CollectTopology(request)
topology = Topology()
Expand All @@ -185,6 +190,7 @@ async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
return topology

async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
await self._ensure_connected()
tensor = None
if isinstance(result, np.ndarray):
tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
Expand All @@ -193,8 +199,9 @@ async def send_result(self, request_id: str, result: List[int], is_finished: boo
await self.stub.SendResult(request)

async def send_opaque_status(self, request_id: str, status: str) -> None:
await self._ensure_connected()
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
await self.stub.SendOpaqueStatus(request)
await asyncio.wait_for(self.stub.SendOpaqueStatus(request), timeout=10.0)

def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
proto_inference_state = node_service_pb2.InferenceState()
Expand Down
2 changes: 2 additions & 0 deletions exo/networking/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ async def start(self) -> None:
("grpc.max_concurrent_streams", 100),
("grpc.tcp_nodelay", 1),
("grpc.optimization_target", "throughput"),
("grpc.keepalive_permit_without_calls", 1),
("grpc.http2.max_concurrent_streams", 0), # Unlimited concurrent streams
],
)
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
Expand Down

0 comments on commit 4081305

Please sign in to comment.