This repository was archived by the owner on Oct 16, 2023. It is now read-only.

refactor tp load checkpoint #114

Merged
merged 3 commits on Aug 22, 2022
34 changes: 17 additions & 17 deletions energonai/model/attention.py
@@ -3,20 +3,21 @@
from torch import nn, dtype

from colossalai.nn.layer.utils import divide
from colossalai.nn import Linear1D_Col, Linear1D_Row
from energonai.nn import Linear1D_Col, Linear1D_Row

from energonai.utils import get_current_device


class MultiHeadAttention1D(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
bias: bool = True,
dtype: dtype = torch.float16,
max_seq_len: int = 512,
fused_qkv: bool = True,
is_decoder:bool = True
) -> None:
hidden_size: int,
num_heads: int,
bias: bool = True,
dtype: dtype = torch.float16,
max_seq_len: int = 512,
fused_qkv: bool = True,
is_decoder: bool = True
) -> None:
super().__init__()

self.hidden_size = hidden_size
@@ -37,16 +38,16 @@ def __init__(self,

if is_decoder:
self.causal_mask = torch.tril(torch.ones((max_seq_len, max_seq_len), dtype=torch.uint8,
device=get_current_device())).view(1, 1, max_seq_len, max_seq_len).bool()
device=get_current_device())).view(1, 1, max_seq_len, max_seq_len).bool()
self.causal_mask_bias = torch.tensor(-1e4, dtype=dtype, device=get_current_device())

def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3)

def forward(self,
hidden_states,
def forward(self,
hidden_states,
attention_mask=None,
seq_lens=None):

@@ -68,12 +69,12 @@ def forward(self,
v = self._split_heads(v, num_attention_heads, self.attention_head_size)

hidden_states = torch.matmul(q, k.transpose(-1, -2))
hidden_states = hidden_states / math.sqrt(self.attention_head_size)
hidden_states = hidden_states / math.sqrt(self.attention_head_size)
q_len, k_len = q.size(-2), k.size(-2)

if self.is_decoder:
hidden_states = torch.where(self.causal_mask[: ,: ,0:q_len , 0:k_len], hidden_states, self.causal_mask_bias)
hidden_states = torch.where(self.causal_mask[:, :, 0:q_len, 0:k_len], hidden_states, self.causal_mask_bias)

if attention_mask is not None:
hidden_states = hidden_states + attention_mask
hidden_states = self.softmax(hidden_states)
@@ -89,7 +90,6 @@ def forward(self,

return hidden_states


# causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
# device=get_current_device())).view(1, 1, q_len, k_len).bool()
# hidden_states = torch.where(causal_mask, hidden_states, torch.tensor(-1e4, dtype=hidden_states.dtype, device=get_current_device()))
# hidden_states = torch.where(causal_mask, hidden_states, torch.tensor(-1e4, dtype=hidden_states.dtype, device=get_current_device()))
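For readers following the attention change: the cached self.causal_mask replaces the per-call mask construction shown in the commented-out lines above. A minimal single-device sketch of that masking step, with toy shapes chosen purely for illustration (not the module's actual configuration):

import math
import torch

# Standalone sketch of the causal-mask step in MultiHeadAttention1D.forward:
# a lower-triangular mask is built once for max_seq_len and sliced per call.
max_seq_len = 8
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.uint8))
causal_mask = causal_mask.view(1, 1, max_seq_len, max_seq_len).bool()
causal_mask_bias = torch.tensor(-1e4)

q = torch.randn(1, 2, 4, 16)  # (batch, heads, q_len, head_size)
k = torch.randn(1, 2, 4, 16)
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))

q_len, k_len = q.size(-2), k.size(-2)
# Future positions receive a large negative bias before softmax.
scores = torch.where(causal_mask[:, :, 0:q_len, 0:k_len], scores, causal_mask_bias)
probs = scores.softmax(dim=-1)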
10 changes: 5 additions & 5 deletions energonai/model/downstream.py
@@ -1,13 +1,13 @@
from torch import dtype, nn
from colossalai.nn import Classifier1D
from energonai.nn import Classifier1D


class LMHead1D(nn.Module):
def __init__(self,
hidden_size:int,
vocab_size:int,
hidden_size: int,
vocab_size: int,
word_embedding_weight: nn.Parameter = None,
bias:bool = False,
bias: bool = False,
dtype: dtype = None) -> None:
super().__init__()
self.dense = Classifier1D(hidden_size, vocab_size, word_embedding_weight, bias=bias, dtype=dtype)
@@ -18,4 +18,4 @@ def weight(self):

def forward(self, x):
x = self.dense(x)
return x
return x
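LMHead1D optionally ties its output projection to the word-embedding weight via Classifier1D (now imported from energonai.nn). A rough single-device analogue of that weight tying, sketched with plain torch.nn and hypothetical names, no tensor parallelism:

import torch
from torch import nn

class TiedLMHead(nn.Module):
    """Single-device stand-in for LMHead1D: reuse the embedding weight when given."""

    def __init__(self, hidden_size, vocab_size, word_embedding_weight=None, bias=False):
        super().__init__()
        self.dense = nn.Linear(hidden_size, vocab_size, bias=bias)
        if word_embedding_weight is not None:
            # nn.Embedding weight is (vocab_size, hidden_size), the same layout as nn.Linear's weight.
            self.dense.weight = word_embedding_weight

    def forward(self, x):
        return self.dense(x)

embed = nn.Embedding(100, 32)
head = TiedLMHead(32, 100, word_embedding_weight=embed.weight)
logits = head(torch.randn(2, 5, 32))  # (batch, seq, vocab_size)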
60 changes: 30 additions & 30 deletions energonai/model/endecoder.py
@@ -2,48 +2,48 @@
import torch
from torch import dtype
from torch import nn
from colossalai.nn import LayerNorm1D
from energonai.nn import LayerNorm1D

from .mlp import MLP1D
from .attention import MultiHeadAttention1D



class Block1D(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
mlp_ratio: float,
activation: Callable = nn.functional.gelu,
layernorm_epsilon:float = 1e-5,
dtype: dtype = torch.float16,
bias: bool = True,
apply_post_layernorm: bool = False,
max_seq_len: int = 512,
fused_qkv:bool = True,
is_decoder:bool = True) -> None:
hidden_size: int,
num_heads: int,
mlp_ratio: float,
activation: Callable = nn.functional.gelu,
layernorm_epsilon: float = 1e-5,
dtype: dtype = torch.float16,
bias: bool = True,
apply_post_layernorm: bool = False,
max_seq_len: int = 512,
fused_qkv: bool = True,
is_decoder: bool = True) -> None:
super().__init__()

self.apply_post_layernorm = apply_post_layernorm
self.norm1 = LayerNorm1D(hidden_size, eps=layernorm_epsilon)

self.attn = MultiHeadAttention1D(hidden_size = hidden_size,
num_heads = num_heads,
bias = bias,
dtype = dtype,
max_seq_len = max_seq_len,
fused_qkv = fused_qkv,
is_decoder = is_decoder)
self.attn = MultiHeadAttention1D(hidden_size=hidden_size,
num_heads=num_heads,
bias=bias,
dtype=dtype,
max_seq_len=max_seq_len,
fused_qkv=fused_qkv,
is_decoder=is_decoder)

self.norm2 = LayerNorm1D(hidden_size, eps=layernorm_epsilon)

self.mlp = MLP1D(hidden_size = hidden_size,
mlp_ratio = mlp_ratio,
activation = activation,
dtype = dtype,
bias = bias)
self.mlp = MLP1D(hidden_size=hidden_size,
mlp_ratio=mlp_ratio,
activation=activation,
dtype=dtype,
bias=bias)

def forward(self, hidden_states, attention_mask=None, seq_lens=None):

if not self.apply_post_layernorm:
residual = hidden_states
hidden_states = self.norm1(hidden_states)
@@ -54,11 +54,11 @@ def forward(self, hidden_states, attention_mask=None, seq_lens=None):

if not self.apply_post_layernorm:
residual = hidden_states

hidden_states = self.norm2(hidden_states)

if self.apply_post_layernorm:
residual = hidden_states
hidden_states = residual + self.mlp(hidden_states)
return hidden_states

return hidden_states
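Block1D switches between pre- and post-layernorm residual placement with apply_post_layernorm. A compact sketch of that ordering using stock torch.nn modules, with identity stand-ins for the attention and MLP (names are illustrative, not the library's API):

import torch
from torch import nn

class ToyBlock(nn.Module):
    """Stripped-down Block1D showing where the residual is taken in pre- vs post-LN."""

    def __init__(self, hidden_size, apply_post_layernorm=False):
        super().__init__()
        self.apply_post_layernorm = apply_post_layernorm
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = nn.Identity()  # stand-in for MultiHeadAttention1D
        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Identity()   # stand-in for MLP1D

    def forward(self, hidden_states):
        if not self.apply_post_layernorm:
            residual = hidden_states          # pre-LN: residual taken before the norm
        hidden_states = self.norm1(hidden_states)
        if self.apply_post_layernorm:
            residual = hidden_states          # post-LN: residual taken after the norm
        hidden_states = residual + self.attn(hidden_states)

        if not self.apply_post_layernorm:
            residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        if self.apply_post_layernorm:
            residual = hidden_states
        hidden_states = residual + self.mlp(hidden_states)
        return hidden_states

out = ToyBlock(16)(torch.randn(2, 4, 16))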
4 changes: 2 additions & 2 deletions energonai/model/mlp.py
@@ -2,7 +2,7 @@
from typing import Callable
import torch
from torch import dtype, nn
from colossalai.nn import Linear1D_Col, Linear1D_Row, Classifier1D
from energonai.nn import Linear1D_Col, Linear1D_Row, Classifier1D


class MLP1D(nn.Module):
@@ -23,4 +23,4 @@ def forward(self, hidden_states):
hidden_states = self.dense_1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.dense_2(hidden_states)
return hidden_states
return hidden_states
4 changes: 2 additions & 2 deletions energonai/model/model_factory.py
@@ -9,7 +9,7 @@
from .embedding import Embedding1D
from .downstream import LMHead1D

from colossalai.nn import LayerNorm1D
from energonai.nn import LayerNorm1D
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from energonai.utils import is_using_pp, get_current_device
@@ -158,7 +158,7 @@ def create_pipeline_model(depth: int = 48,
numel = 0
for _, param in model.named_parameters(recurse=True):
numel += param.numel()
logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB!!!')
logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')

if "checkpoint" in model_kwargs.keys() and "model_name" in model_kwargs.keys():
start = time.time()
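The adjusted log line in this hunk reports parameter count times two bytes (fp16) in GB. A quick illustration of that arithmetic on a toy module (the real call iterates the constructed pipeline stage, not this stand-in):

import torch
from torch import nn

# Toy stand-in for the constructed pipeline stage.
model = nn.Linear(1024, 1024).half()

numel = 0
for _, param in model.named_parameters(recurse=True):
    numel += param.numel()

# 2 bytes per fp16 parameter, reported in GB as in the log message above.
print(f'model size = {numel * 2 / 1e9} GB')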
51 changes: 31 additions & 20 deletions energonai/utils/checkpointing.py
@@ -71,26 +71,37 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
"""
src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
depth = gpc.get_world_size(parallel_mode)

if gpc.get_local_rank(parallel_mode) == 0:

partitioned_state_list = [dict() for _ in range(depth)]

for key in list(state_dict.keys()):
param = state_dict.pop(key)
dim = dims.get(key, 0)
do_partition = partition_states.get(key, True)
if do_partition:
param = torch.chunk(param, depth, dim=dim)
for i, p in enumerate(partitioned_state_list):
p[key] = param[i] if do_partition else param

else:
partitioned_state_list = [None for _ in range(depth)]

partitioned_state = [None]
scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
return partitioned_state[0]
group = gpc.get_cpu_group(parallel_mode)
is_rank0 = gpc.get_local_rank(parallel_mode) == 0
partition_info = [None]
if is_rank0:
partition_info_dict = OrderedDict()
for key, param in state_dict.items():
dim = dims[key]
is_partitioned = partition_states[key]
shape = list(param.shape)
if is_partitioned:
shape[dim] = shape[dim] // depth
partition_info_dict[key] = (is_partitioned, param.dtype, shape, dim)
partition_info[0] = partition_info_dict
dist.broadcast_object_list(partition_info, src_rank, group=group)
partitioned_state = OrderedDict()
for key, (is_partitioned, dtype, shape, dim) in partition_info[0].items():
if is_partitioned:
output = torch.empty(shape, dtype=dtype)
if is_rank0:
scatter_list = [t.contiguous() for t in state_dict[key].chunk(depth, dim)]
else:
scatter_list = None
dist.scatter(output, scatter_list, src_rank, group=group)
else:
if is_rank0:
output = state_dict[key]
else:
output = torch.empty(shape, dtype=dtype)
dist.broadcast(output, src_rank, group=group)
partitioned_state[key] = output
return partitioned_state


def gather_tensor_parallel_state_dict(
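The refactor replaces the single scatter_object_list of whole state-dict partitions with a metadata broadcast from rank 0 followed by per-tensor dist.scatter (for partitioned keys) and dist.broadcast (for replicated keys). A single-process sketch of the resulting per-rank partitioning, with toy tensors and no process group, just to show what each rank receives:

from collections import OrderedDict
import torch

# Single-process illustration of the partitioning the refactored code performs:
# partitioned keys are chunked along their split dim (what dist.scatter delivers),
# replicated keys are copied to every rank (what dist.broadcast delivers).
depth = 4  # tensor-parallel world size
state_dict = OrderedDict(weight=torch.randn(8, 16), bias=torch.randn(8))
dims = {'weight': 0, 'bias': 0}
partition_states = {'weight': True, 'bias': False}

per_rank = [OrderedDict() for _ in range(depth)]
for key, param in state_dict.items():
    if partition_states[key]:
        chunks = [t.contiguous() for t in param.chunk(depth, dims[key])]
        for rank_state, chunk in zip(per_rank, chunks):
            rank_state[key] = chunk
    else:
        for rank_state in per_rank:
            rank_state[key] = param

assert per_rank[0]['weight'].shape == (2, 16)
assert per_rank[3]['bias'].shape == (8,)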
44 changes: 22 additions & 22 deletions examples/auto_pipeline/bert.py
@@ -9,8 +9,8 @@
from colossalai.core import global_context as gpc
from energonai.logging import get_dist_logger
from colossalai.nn.layer.utils import divide, ACT2FN
from colossalai.nn import Linear1D_Col, Linear1D_Row, Classifier1D
from colossalai.nn import LayerNorm1D
from energonai.nn import Linear1D_Col, Linear1D_Row, Classifier1D
from energonai.nn import LayerNorm1D
from energonai.kernel import transpose_pad, transpose_depad, depad
from energonai.nn import VocabParallelEmbedding1D
from energonai.utils import get_current_device, is_using_pp
@@ -54,7 +54,7 @@ def forward(self, input_ids, position_ids=None, tokentype_ids=None):
# if position_ids is None:
# position_ids = torch.arange(max_padding_size, dtype=torch.long, device=get_current_device()).unsqueeze(0)

x = self.word_embeddings(input_ids) # + self.position_embeddings(position_ids)
x = self.word_embeddings(input_ids) # + self.position_embeddings(position_ids)

if self.tokentype_embeddings is not None and tokentype_ids is not None:
x = x + self.tokentype_embeddings(tokentype_ids)
@@ -124,7 +124,7 @@ def forward(self, hidden_states, attention_mask=None):
# if seq_lens is not None:
# sum_seq = torch.sum(seq_lens)
# attention_output = transpose_depad(attention_output, batch_size, sum_seq, max_padding_size, seq_lens,
# num_attention_heads, self.attention_head_size)
# num_attention_heads, self.attention_head_size)
# else:
attention_output = attention_output.permute(0, 2, 1, 3).contiguous()

@@ -207,8 +207,9 @@ def forward(self, hidden_states, attention_mask):

return hidden_states


class Bert1D(nn.Module):

def __init__(self,
vocab_size: int = 50304,
max_position_embeddings: int = 1024,
@@ -225,27 +226,27 @@ def __init__(self,
):
super().__init__()
self.embed = BertEmbedding1D(embedding_dim=hidden_size,
vocab_size=vocab_size,
max_position_embeddings=max_position_embeddings,
padding_idx=padding_idx,
layernorm_epsilon=layernorm_epsilon,
dtype=dtype)
vocab_size=vocab_size,
max_position_embeddings=max_position_embeddings,
padding_idx=padding_idx,
layernorm_epsilon=layernorm_epsilon,
dtype=dtype)
self.blocks = nn.ModuleList()

for i in range(depth):
self.blocks.append(BertTransformerLayer1D(
hidden_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
activation=activation,
layernorm_epsilon=layernorm_epsilon,
dtype=dtype,
bias=bias,
fuse_scale_mask_softmax=fuse_scale_mask_softmax,)
)
hidden_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
activation=activation,
layernorm_epsilon=layernorm_epsilon,
dtype=dtype,
bias=bias,
fuse_scale_mask_softmax=fuse_scale_mask_softmax,)
)

def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None):

# batch_size = input_ids.shape[0]
# max_padding_size = input_ids.shape[1]

@@ -259,7 +260,6 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None):
return hidden_states



def _create_bert_model(model_kwargs):
model = Bert1D(**model_kwargs)
return model
8 changes: 4 additions & 4 deletions examples/bert/bert.py
@@ -9,8 +9,8 @@
from colossalai.core import global_context as gpc
from energonai.logging import get_dist_logger
from colossalai.nn.layer.utils import divide, ACT2FN
from colossalai.nn import Linear1D_Col, Linear1D_Row, Classifier1D
from colossalai.nn import LayerNorm1D
from energonai.nn import Linear1D_Col, Linear1D_Row, Classifier1D
from energonai.nn import LayerNorm1D
from energonai.kernel import transpose_pad, transpose_depad, depad
from energonai.nn import VocabParallelEmbedding1D
from energonai.utils import get_current_device, is_using_pp
@@ -241,8 +241,8 @@ def __init__(self,
dtype=dtype,
bias=bias,
fuse_scale_mask_softmax=fuse_scale_mask_softmax,
)
)
)
)

def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None):
