[colotensor] use cpu memory to store state_dict (#1367)
1SAA authored Jul 26, 2022
1 parent 943a963 commit 87775a0
Showing 4 changed files with 26 additions and 5 deletions.
3 changes: 2 additions & 1 deletion colossalai/nn/parallel/data_parallel.py
@@ -318,7 +318,8 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
             self.chunk_manager.access_chunk(chunk)
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
-                destination[prefix + name] = fp32_p.clone() if keep_vars else fp32_p.clone().detach()
+                rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
+                destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
         for chunk in chunks:
             self.chunk_manager.release_chunk(chunk)
         for name, buf in self.named_buffers():
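
The change above copies each fp32 shadow parameter into host memory before recording it, so building the state dict no longer keeps a second GPU-resident copy of every parameter. Below is a minimal sketch of the same pattern with a hypothetical helper name, operating on a plain nn.Module rather than the ColoDDP chunk machinery:

import torch

def save_to_cpu_state_dict(module: torch.nn.Module, prefix: str = "") -> dict:
    """Collect parameters into a CPU-resident dict (hypothetical helper, not the ColossalAI API)."""
    destination = {}
    for name, p in module.named_parameters():
        if p is None:
            continue
        # clone() when the parameter is already on CPU, otherwise .cpu() copies it off the GPU
        rec_p = p.clone() if p.device.type == 'cpu' else p.cpu()
        destination[prefix + name] = rec_p.detach()
    return destination
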
18 changes: 16 additions & 2 deletions colossalai/utils/checkpoint/utils.py
@@ -4,6 +4,20 @@
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
 
 
+def robust_broadcast(tensor):
+    with torch.no_grad():
+        is_cpu_ten = tensor.device.type == 'cpu'
+        if is_cpu_ten:
+            b_data = tensor.cuda()
+        else:
+            b_data = tensor
+
+        dist.broadcast(b_data, 0)
+
+        if is_cpu_ten:
+            tensor.copy_(b_data)
+
+
 def gather_tensor(colo_tensor: ColoTensor) -> None:
     """Make colo_tensor replicated when the rank is 0
     """
@@ -27,15 +41,15 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
     """Reversal operation of `gather_tensor`.
     """
     if dist_spec.placement == DistPlacementPattern.REPLICATE:
-        dist.broadcast(colo_tensor.data, 0)
+        robust_broadcast(colo_tensor.data)
     else:
         global_size = colo_tensor.size_global()
 
         if dist.get_rank() == 0:
             entire_data = colo_tensor.data
         else:
             entire_data = torch.empty(global_size, device=colo_tensor.device)
-        dist.broadcast(entire_data, 0)
+        robust_broadcast(entire_data)
 
     if dist.get_rank() == 0:
         colo_tensor.set_dist_spec(dist_spec)
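
robust_broadcast is needed because the NCCL backend used in these tests can only broadcast CUDA tensors: a CPU tensor is staged into a temporary GPU buffer, broadcast from rank 0, and copied back into the original CPU storage. A hedged usage sketch, assuming an already-initialized NCCL process group and that the helper can be imported from the module shown in this diff:

import torch
import torch.distributed as dist
from colossalai.utils.checkpoint.utils import robust_broadcast  # path taken from the diff above

def demo_cpu_broadcast():
    # rank 0 holds the real values; other ranks hold a CPU placeholder of the same shape
    rank = dist.get_rank()
    t = torch.arange(4, dtype=torch.float32) if rank == 0 else torch.zeros(4)
    robust_broadcast(t)
    # every rank now holds [0., 1., 2., 3.] in its original CPU tensor
    assert torch.equal(t, torch.arange(4, dtype=torch.float32))
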
8 changes: 7 additions & 1 deletion tests/test_ddp/test_ddp_state_dict.py
@@ -19,7 +19,13 @@
 def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
     for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()):
         assert k1 == k2
-        assert torch.allclose(t1, t2, atol=1e-3, rtol=1e-3)
+
+        if t1.device != t2.device:
+            temp_t2 = t2.to(t1.device)
+        else:
+            temp_t2 = t2
+
+        assert torch.allclose(t1, temp_t2, atol=1e-3, rtol=1e-3)
 
 
 def init_ddp(module: torch.nn.Module) -> ColoDDP:
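
The helper now moves t2 onto t1's device before comparing, since the wrapped model's state dict lives on the CPU after this commit while the reference state dict may still be on the GPU, and torch.allclose raises when its arguments are on different devices. A small standalone illustration (hypothetical tensors, requires a CUDA device):

import torch

cpu_sd = {"weight": torch.ones(2, 2)}                   # state dict now kept on CPU
cuda_sd = {"weight": torch.ones(2, 2, device="cuda")}   # reference still on the GPU

for (k1, t1), (k2, t2) in zip(cpu_sd.items(), cuda_sd.items()):
    assert k1 == k2
    temp_t2 = t2.to(t1.device) if t1.device != t2.device else t2
    assert torch.allclose(t1, temp_t2, atol=1e-3, rtol=1e-3)
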
2 changes: 1 addition & 1 deletion tests/test_tensor/test_colo_checkpoint_tools.py
@@ -17,7 +17,7 @@
 def run_dist(rank, world_size, port, dp_degree, tp_degree):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
-    x = torch.randn(4, 4, device=get_current_device())
+    x = torch.randn(4, 4)
     param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
     spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
     param.set_tensor_spec(*spec)
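
Building the test tensor on the CPU instead of get_current_device() makes gather_tensor and scatter_tensor exercise the new robust_broadcast staging path. A hedged sketch of that round trip; the import paths are assumptions, since the test file's own import block is not shown in this hunk:

import torch
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec)
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor

def cpu_roundtrip(pg: ProcessGroup):
    x = torch.randn(4, 4)  # CPU tensor, matching the updated test
    param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    param.set_tensor_spec(shard_spec, ComputeSpec(ComputePattern.TP1D))
    gather_tensor(param)               # rank 0 ends up with the full, replicated tensor
    scatter_tensor(param, shard_spec)  # shard it back out according to the ShardSpec
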
