diff --git a/energon/engine/rpc_worker.py b/energon/engine/rpc_worker.py
index 3d94567..3e0c118 100644
--- a/energon/engine/rpc_worker.py
+++ b/energon/engine/rpc_worker.py
@@ -1,4 +1,5 @@
 import os
+import time
 import torch
 import inspect
 import torch.distributed.rpc as rpc
@@ -14,6 +15,19 @@
     "bert": BertPipelineCommWrapper,
 }
 
+class ReturnDict:
+    def __init__(self):
+        self.rd = dict()
+
+    def enqueue(self, key, output):
+        self.rd[key] = output
+
+    def top(self, key):
+        while key not in self.rd:
+            time.sleep(0.001)
+        output = self.rd[key]
+        return output
+
 class RPCWorker:
     def __init__(self,
                  model_class,
@@ -31,10 +45,12 @@ def __init__(self,
         self.model = None # call the model
         self.rank = gpc.get_local_rank(ParallelMode.GLOBAL)
 
-        torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')
-
+        torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')
         self._init_self()
 
+        if gpc.is_initialized(ParallelMode.PIPELINE):
+            self.return_dict = ReturnDict()
+
     def _init_self(self):
         print("[INFO] init model in rank {}".format(self.rank))
 
@@ -50,17 +66,22 @@ def _init_self(self):
             self.model = self.pipe_wrapper(model = self.model, max_batch_size = self.max_batch_size, dtype=self.dtype)
 
     def run(self, key, inputs):
+        # print("key: {}".format(key), flush=True)
        torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')
         for k, v in inputs.items():
             if v is not None:
                 inputs[k] = v.cuda() #non_blocking=True
 
         if gpc.is_initialized(ParallelMode.PIPELINE):
-            output = self.model.run(key, inputs)
+            if gpc.is_last_rank(ParallelMode.PIPELINE):
+                output, cur_key = self.model.run(key, inputs)
+                self.return_dict.enqueue(cur_key, output.cpu())
+                return self.return_dict.top(key)
+            else:
+                self.model.run(key, inputs)
+                return None
         else:
             output = self.model(**inputs)
-
-        if output is not None:
-            return output.cpu() #non_blocking=True
+            return output
 
         return None
\ No newline at end of file
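
For context, here is a minimal standalone sketch of the blocking hand-off that the new `ReturnDict` implements. On the last pipeline rank, `enqueue` stores whichever batch just came off the pipeline under its `cur_key`, while `top(key)` busy-waits at 1 ms granularity until the output for the caller's own `key` has arrived. The threading driver below is illustrative only and not part of the patch.

```python
import threading
import time


class ReturnDict:
    """Blocking key -> output store used by the last pipeline stage."""

    def __init__(self):
        self.rd = dict()

    def enqueue(self, key, output):
        # Called when a finished batch (cur_key) comes off the pipeline.
        self.rd[key] = output

    def top(self, key):
        # Poll every 1 ms until the requested key has been enqueued.
        while key not in self.rd:
            time.sleep(0.001)
        return self.rd[key]


if __name__ == "__main__":
    rd = ReturnDict()

    def producer():
        time.sleep(0.05)  # pretend the pipeline takes 50 ms for this batch
        rd.enqueue(42, "output for batch 42")

    threading.Thread(target=producer).start()
    print(rd.top(42))  # blocks until the producer enqueues key 42
```

The polling loop keeps the class dependency-free at the cost of up to about 1 ms of extra latency per lookup; a per-key `threading.Event` or a shared `Condition` could avoid the busy-wait if that overhead matters.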