rpc_worker, return results in order #42

Merged 1 commit on May 5, 2022

33 changes: 27 additions & 6 deletions energon/engine/rpc_worker.py
@@ -1,4 +1,5 @@
import os
import time
import torch
import inspect
import torch.distributed.rpc as rpc
@@ -14,6 +15,19 @@
"bert": BertPipelineCommWrapper,
}

class ReturnDict:
    # Buffers outputs by request key so results can be handed back in request order.
    def __init__(self):
        self.rd = dict()

    def enqueue(self, key, output):
        self.rd[key] = output

    def top(self, key):
        # Block until the output for this key has been enqueued, polling every 1 ms.
        while key not in self.rd:
            time.sleep(0.001)
        output = self.rd[key]
        return output
class RPCWorker:
    def __init__(self,
                 model_class,
@@ -31,10 +45,12 @@ def __init__(self,
        self.model = None  # built in _init_self()
        self.rank = gpc.get_local_rank(ParallelMode.GLOBAL)

        torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')
        self._init_self()

        if gpc.is_initialized(ParallelMode.PIPELINE):
            self.return_dict = ReturnDict()

    def _init_self(self):
        print("[INFO] init model in rank {}".format(self.rank))

@@ -50,17 +66,22 @@ def _init_self(self):
        self.model = self.pipe_wrapper(model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)

    def run(self, key, inputs):
        torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')
        for k, v in inputs.items():
            if v is not None:
                inputs[k] = v.cuda()  # non_blocking=True

        if gpc.is_initialized(ParallelMode.PIPELINE):
            if gpc.is_last_rank(ParallelMode.PIPELINE):
                # The pipeline can finish requests out of order: buffer whatever
                # result came out (cur_key), then block until this call's own key is ready.
                output, cur_key = self.model.run(key, inputs)
                self.return_dict.enqueue(cur_key, output.cpu())
                return self.return_dict.top(key)
            else:
                self.model.run(key, inputs)
                return None
        else:
            output = self.model(**inputs)
            if output is not None:
                return output.cpu()  # non_blocking=True

        return None
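
The ordering trick above is easiest to see in isolation. Below is a minimal standalone sketch (not part of the PR; the request keys, delays, and the finish helper are made up for illustration): two results are produced out of order, yet each caller blocks in top until the entry for its own key appears, just as the last pipeline rank does in run.

import threading
import time

class ReturnDict:
    def __init__(self):
        self.rd = dict()

    def enqueue(self, key, output):
        self.rd[key] = output

    def top(self, key):
        while key not in self.rd:  # poll until the key shows up
            time.sleep(0.001)
        return self.rd[key]

rd = ReturnDict()

def finish(key, delay):
    # Stand-in for the pipeline emitting a result after some latency.
    time.sleep(delay)
    rd.enqueue(key, "output-for-" + key)

# "req-1" completes before "req-0", i.e. results arrive out of order.
threading.Thread(target=finish, args=("req-0", 0.2)).start()
threading.Thread(target=finish, args=("req-1", 0.1)).start()

print(rd.top("req-0"))  # blocks about 0.2 s, then prints output-for-req-0
print(rd.top("req-1"))  # already buffered, returns immediately

Two properties of the design are worth noting: top busy-polls at 1 ms granularity, and entries are never removed from the dict, so a long-running worker accumulates past outputs. Waiting on a threading.Condition would avoid the polling, and popping the key inside top would release finished outputs.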