llm int8 integration into Megatron part of metaseq #10

Open · wants to merge 4 commits into fairseq_v3
33 changes: 22 additions & 11 deletions megatron/mpu/layers.py
@@ -39,10 +39,10 @@
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
+from .quantization_utils import quantization_init
from megatron import get_args, get_global_memory_buffer
# from megatron.model.fused_bias_gelu import bias_gelu, bias_gelu_back


_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1}
@@ -263,7 +263,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):

@staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
-async_grad_allreduce, sequence_parallel, apply_pre_gelu=False, apply_pre_ln=False):
+async_grad_allreduce, sequence_parallel, apply_pre_gelu=False, apply_pre_ln=False, q_linear=None):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
@@ -292,9 +292,13 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
if ctx.apply_pre_gelu:
total_input = gelu(total_input)

-output = torch.matmul(total_input, weight.t())
-if bias is not None:
-    output = output + bias
+if q_linear is not None:
+    # if q_linear is passed, regular weight will be ignored
+    output = q_linear(total_input)
+else:
+    output = torch.matmul(total_input, weight.t())
+    if bias is not None:
+        output = output + bias
return output

@staticmethod
@@ -377,7 +381,6 @@ def backward(ctx, grad_output):

return grad_input, grad_weight, grad_bias, None, None, None, None


class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.

@@ -427,7 +430,10 @@ def __init__(self, input_size, output_size, bias=True, gather_output=True,
# we allocate the transpose.
# Initialize weight.
# args = get_args()
-if use_cpu_initialization:
+self.q_linear = quantization_init(self, self.input_size, self.output_size_per_partition, dtype)
+if self.q_linear is not None:
+    pass # do nothing, quantization_init will take care of it
+elif use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size,
dtype=dtype
@@ -493,7 +499,8 @@ def forward(self, input_):
# Matrix multiply.
output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
-self.async_tensor_model_parallel_allreduce, self.sequence_parallel, False)
+self.async_tensor_model_parallel_allreduce, self.sequence_parallel, False, False,
+self.q_linear)
if self.gather_output:
# All-gather across the partitions.
assert not self.sequence_parallel
@@ -558,7 +565,10 @@ def __init__(self, input_size, output_size, bias=True,
# we allocate the transpose.
# Initialize weight.
# args = get_args()
-if use_cpu_initialization:
+self.q_linear = quantization_init(self, self.input_size_per_partition, self.output_size, dtype)
+if self.q_linear is not None:
+    pass # do nothing, quantization_init will take care of it
+elif use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition,
dtype=dtype,
@@ -601,7 +611,8 @@ def forward(self, input_):
# Matrix multiply.
output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, self.weight, None,
-self.gradient_accumulation_fusion, None, None, self.apply_pre_gelu)
+self.gradient_accumulation_fusion, None, None, self.apply_pre_gelu, False,
+self.q_linear)
# All-reduce across all the partitions.
if self.sequence_parallel:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
@@ -667,4 +678,4 @@ def forward(self, input_):
# else:
# output = output_
# output_bias = self.bias
# return output, output_bias
# return output, output_bias
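
Taken out of its tensor-parallel context, the dispatch added to LinearWithGradAccumulationAndAsyncCommunication.forward reduces to the small function below. This is a simplified sketch for illustration only (plain Python, no autograd.Function or communication plumbing), and the name _quantized_dispatch is made up here:

import torch

def _quantized_dispatch(total_input, weight, bias=None, q_linear=None):
    # An attached quantized linear module replaces the fp16 matmul entirely;
    # bias is only applied on the unquantized path, which matches the post hook
    # in quantization_utils.py rejecting modules that carry a bias.
    if q_linear is not None:
        return q_linear(total_input)
    output = torch.matmul(total_input, weight.t())
    if bias is not None:
        output = output + bias
    return output

As the hunks above show, ColumnParallelLinear and RowParallelLinear now pass their (possibly None) self.q_linear as the final positional argument to .apply(), alongside an explicit False for apply_pre_ln.
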
90 changes: 90 additions & 0 deletions megatron/mpu/quantization_utils.py
@@ -0,0 +1,90 @@
import torch

QUANTIZATION_LEVEL = 0 # 0 == None, 1 == LLMint8, 2 == Smoothquant W8A16, 3 == Smoothquant W8A8
QUANTIZATION_IS_LOAD_STATE_DICT = True # Only flip to False for benchmarking purposes if not loading state dict
Comment on lines +3 to +4 (Author): Set QUANTIZATION_LEVEL here for different quantization types

if QUANTIZATION_LEVEL == 0:
    pass # no quantization, do nothing
elif QUANTIZATION_LEVEL == 1:
    from bitsandbytes.nn import Linear8bitLt, Int8Params
elif QUANTIZATION_LEVEL == 2:
    from torch_int.nn.linear import W8A16Linear
elif QUANTIZATION_LEVEL == 3:
    raise Exception("Quantization level 3 currently not supported")
    from torch_int.nn.linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear
    # Put in W8A8 stuff here eventually
else:
    raise Exception("This quantization level is not supported")

def create_llmint8_linear(weight, bias=None, has_fp16_weights=False, threshold=6.0, index=None):
    """
    weight: any weight dtype is fine (fp32, bf16, or fp16). We assume this is a CPU weight, to be converted to int8 upon sending to GPU.
        From Tim Dettmers: "when cuda() or to(device) is called, Int8Param class should intercept, cast the CPU weight to fp16 and do the transformation to int8 and then cast it to device/cuda."
    bias: the actual bias tensor; can also be fp32, bf16, or fp16 (optional).
    Other arg explanations TBD.
    """
    output_features, input_features = weight.shape
    has_bias = bias is not None
    q_linear = Linear8bitLt(input_features, output_features, bias=has_bias, has_fp16_weights=has_fp16_weights, threshold=threshold, index=index)
    q_linear.weight = weight
    return q_linear
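
For reference, a minimal self-contained sketch of the conversion this helper relies on (not part of the diff): it assumes bitsandbytes is installed, a CUDA device is available, and uses arbitrary shapes. The weight is wrapped in Int8Params on CPU, and moving the layer to the GPU is what triggers the fp16-to-int8 transformation, as described in the docstring above.

import torch
from bitsandbytes.nn import Int8Params, Linear8bitLt

# fp16 weight held on CPU, wrapped the same way quantized_inference_post_hook does below
weight = Int8Params(data=torch.randn(1024, 4096, dtype=torch.float16),
                    requires_grad=False, has_fp16_weights=False)
q_linear = Linear8bitLt(4096, 1024, bias=False, has_fp16_weights=False, threshold=6.0)
q_linear.weight = weight
q_linear.cuda()  # Int8Params intercepts the move and quantizes the weight to int8 here

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
y = q_linear(x)  # int8 matmul; activation outliers above the threshold stay in fp16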

def quantized_inference_post_hook(module, incompatible_keys=None):
    """
    Holds configs on which linear to create.
    Decides which linear to use and sends it to CUDA in the correct way so you don't blow up your GPU memory.
    """
    # we assume that the weight has been put on CPU and we disabled the initialization
    has_bias = module.bias is not None
    if has_bias:
        raise Exception("Int8 conversion currently does not support bias.")

    if QUANTIZATION_LEVEL == 1:
        module.weight = Int8Params(data=module.weight, has_fp16_weights=False, requires_grad=False)  # on CPU
        # recommended threshold is 6.0, but can tweak. see llm_int8 paper for how to set
        module.q_linear = create_llmint8_linear(module.weight, bias=None, has_fp16_weights=False, threshold=6.0, index=None)
        module.q_linear.to(torch.cuda.current_device())  # send it over and get int8!
    elif QUANTIZATION_LEVEL == 2:
        output_features, input_features = module.weight.shape
        # create an empty temporary linear for W8A16Linear to latch onto
        temp_linear = torch.nn.Linear(input_features, output_features, bias=has_bias, device="meta")
        temp_linear.weight = module.weight
        temp_linear = temp_linear.cuda()
        module.q_linear = W8A16Linear.from_float(temp_linear)
        module.q_linear.dequant_type = module.dtype
        # clean up old weight
        del temp_linear
        module.weight = None
    else:
        raise Exception("Other quantization levels not currently supported")

def quantized_inference_pre_hook(module, state_dict=None, prefix=None, local_metadata=None, strict=None, missing_keys=None, unexpected_keys=None, error_msgs=None):
    """
    Create the module weight right before the state dict load so you don't blow up CPU memory.
    The CPU weight created here is moved to GPU immediately after load_state_dict runs on this particular weight, so it won't stay in CPU memory for long.
    Unnecessary args are just there to match the _register_load_state_dict_pre_hook method signature.
    """
    module.weight = torch.nn.Parameter(torch.empty((module.quantized_output_size, module.quantized_input_size), dtype=module.dtype, device="cpu"), requires_grad=False)

def quantization_init(module, input_size, output_size, dtype):
    """
    If not quantizing, returns None. Otherwise, returns True.
    If quantizing, module.q_linear will be replaced with the chosen int8 linear implementation when the checkpoint is loaded into the model.
    """
    if QUANTIZATION_LEVEL == 0:
        return None

    # be careful that these are not overriding some other attribute on the module.
    # necessary because hooks cannot take arguments besides the module itself
    module.quantized_input_size = input_size
    module.quantized_output_size = output_size
    module.dtype = dtype

    if QUANTIZATION_IS_LOAD_STATE_DICT is True:
        module._register_load_state_dict_pre_hook(quantized_inference_pre_hook, with_module=True)
        module.register_load_state_dict_post_hook(quantized_inference_post_hook)
    else:
        # if we aren't loading a state dict, call the hooks directly to initialize quantized weights on model creation
        quantized_inference_pre_hook(module)
        quantized_inference_post_hook(module)
    return True  # A placeholder to represent that module.q_linear will be replaced during load_state_dict on the model
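
To make the wiring concrete, a hedged usage sketch (not part of the PR) of how a module that calls quantization_init ends up with an int8 q_linear after load_state_dict. It assumes QUANTIZATION_LEVEL has been edited to 1 in this file, bitsandbytes is installed, and a CUDA device is available; TinyLinear and its sizes are illustrative stand-ins for the parallel linear layers in layers.py.

import torch
from megatron.mpu.quantization_utils import quantization_init

class TinyLinear(torch.nn.Module):
    def __init__(self, input_size, output_size, dtype=torch.float16):
        super().__init__()
        self.bias = None  # the int8 post hook rejects modules with a bias
        # Registers the load_state_dict pre/post hooks; no weight is allocated yet.
        self.q_linear = quantization_init(self, input_size, output_size, dtype)
        if self.q_linear is None:  # QUANTIZATION_LEVEL == 0: plain fp16 weight
            self.weight = torch.nn.Parameter(torch.empty(output_size, input_size, dtype=dtype))

    def forward(self, x):
        if self.q_linear is not None:
            return self.q_linear(x)
        return torch.nn.functional.linear(x, self.weight, self.bias)

layer = TinyLinear(4096, 1024)
# The pre hook allocates the CPU weight just in time for the copy; the post hook
# wraps it in Int8Params, builds a Linear8bitLt, and moves it to the current GPU.
layer.load_state_dict({"weight": torch.randn(1024, 4096, dtype=torch.float16)})
out = layer(torch.randn(8, 4096, dtype=torch.float16, device="cuda"))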