From 00def03f36ed2375780da77a8ffdbd5522c1b7f4 Mon Sep 17 00:00:00 2001 From: dujiangsu Date: Tue, 15 Feb 2022 10:37:21 +0800 Subject: [PATCH] fp16 support --- energon/engine/engine.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/energon/engine/engine.py b/energon/engine/engine.py index 4b8b2b5..22acef4 100644 --- a/energon/engine/engine.py +++ b/energon/engine/engine.py @@ -49,10 +49,6 @@ def __init__(self, # self._save_parameter() self._load_parameter() - if self._dtype: - self._dtype_convert() - - def _init_dist(self): launch_from_torch(tp_size = self._tp_size, pp_size = self._pp_size) @@ -63,7 +59,13 @@ def _set_sample_device(self): self._samples[k] = v.cuda() def _init_model(self): - model = self._model_class(**self._model_config).cuda() + """ + TODO(dujiangsu) support other dtype + """ + if self._dtype == torch.half: + model = self._model_class(**self._model_config).cuda().half() + else: + model = self._model_class(**self._model_config).cuda() model.eval() self._model = PipelineCommWrapper(model = model, sample = self._samples, dtype=self._dtype) @@ -96,7 +98,7 @@ def _get_ranks_name(self): ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}' return ranks_name - def _dtype_convert(self): + def dtype_convert(self): """ TODO(dujiangsu) support other dtype """