Commit 54e83db

jwfromm authored and facebook-github-bot committed
Add optional zero_start_index_M argument to triton fp8 rowwise quantization (pytorch#3628)
Summary:
Pull Request resolved: pytorch#3628
X-link: facebookresearch/FBGEMM#705

In MOE models, many rows may be sparsely populated. There's no reason to run quantization on these empty values. This diff adds a new optional argument to fp8 rowwise quantization that allows skipping over the sparse region of rows.

Reviewed By: jasonjk-park, jianyuh, jiawenliu64

Differential Revision: D68797978

fbshipit-source-id: 0142427bb9324592fa29d2e162f1edd8d9fd1c9c
1 parent 5b048ab commit 54e83db
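
For context, a minimal usage sketch of the new argument (illustration only; the shapes, dtype, and import path are assumptions based on the file locations in this commit, not taken from it). Each row of an MoE-style activation carries valid data only in a leading prefix, and zero_start_index_M tells the quantizer how many leading elements per row are populated, so the kernel can skip the zeroed tail:

    import torch
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row  # assumed import path

    # Hypothetical [experts, tokens, hidden] activation with sparsely populated rows.
    x = torch.randn(4, 128, 256, device="cuda", dtype=torch.bfloat16)
    valid_per_row = torch.randint(0, 257, (4, 128), device="cuda")            # nonzero count per row
    x = x * (torch.arange(256, device="cuda") < valid_per_row.unsqueeze(-1))  # zero out each row's tail

    x_fp8, x_scale = quantize_fp8_row(x, zero_start_index_M=valid_per_row)    # quantization skips the zeroed tails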

2 files changed: +81 -11 lines changed

fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py (+32 -2)
@@ -37,6 +37,7 @@ def _test_quantize_fp8_row(
             use_triton: bool,
             device: torch.device,
             output_device: Optional[torch.device] = None,
+            use_jagged: bool = False,
             use_scale_ub: bool = False,
             transpose_inputs: bool = False,
         ) -> None:
@@ -49,16 +50,30 @@ def _test_quantize_fp8_row(
                 for dim1, dim2 in itertools.combinations(dims, 2):
                     dims_list = list(dims)
                     dims_list[dim1], dims_list[dim2] = dims_list[dim2], dims_list[dim1]
-                    inputs.append(a.permute(dims_list))
+                    inputs.append(a.clone().permute(dims_list))
             scale_ub = (
                 torch.tensor([1200], dtype=torch.float, device=device)
                 if use_scale_ub
                 else None
             )
             for input_a in inputs:
+                # Apply sparsification if specified.
+                zero_start_index_M = None
+                if use_jagged:
+                    m_vals = torch.randint(
+                        0, input_a.shape[-1] + 1, (input_a.shape[:-1])
+                    )
+                    mask = torch.arange(input_a.shape[-1]).expand(
+                        input_a.shape[:-1] + (input_a.shape[-1],)
+                    ) >= m_vals.unsqueeze(-1)
+                    # Set corresponding values to 0.
+                    input_a[mask] = 0.0
+                    # Generate nonzero tensor in same layout as input.
+                    zero_start_index_M = torch.count_nonzero(input_a, dim=-1)
                 a_fp8, a_scale = quantize_fp8_row(
                     input_a,
                     scale_ub=scale_ub,
+                    zero_start_index_M=zero_start_index_M,
                     use_triton=use_triton,
                     output_device=output_device,
                 )
@@ -73,7 +88,10 @@ def _test_quantize_fp8_row(
 
             self.assertTrue(
                 torch.allclose(
-                    input_a.to(device=output_device), a_torch, atol=2e-1, rtol=1e-1
+                    input_a.to(device=output_device),
+                    a_torch,
+                    atol=2e-1,
+                    rtol=1e-1,
                 )
             )
 
@@ -97,6 +115,18 @@ def _test_quantize_fp8_row(
         )
         _test_quantize_fp8_row((4, 2, 3), True, torch.device("cpu"))
         _test_quantize_fp8_row((6, 4, 2, 3), True, torch.device("cpu"))
+        # Test with zero_start_index_M
+        _test_quantize_fp8_row((20, 30), True, torch.device("cuda"), use_jagged=True)
+        _test_quantize_fp8_row(
+            (6, 4, 2, 3), True, torch.device("cuda"), use_jagged=True
+        )
+        _test_quantize_fp8_row(
+            (4, 2, 3),
+            True,
+            torch.device("cuda"),
+            transpose_inputs=True,
+            use_jagged=True,
+        )
 
     def test_scale_fp8_row(self) -> None:
         def _test_scale_fp8_row(
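
To make the jagged setup in the test above concrete, here is a tiny standalone illustration (made-up values, CPU only) of the mask and metadata the test derives before calling quantize_fp8_row:

    import torch

    x = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                      [5.0, 6.0, 7.0, 8.0]])
    m_vals = torch.tensor([2, 3])  # per-row valid lengths (random in the real test)
    mask = torch.arange(x.shape[-1]) >= m_vals.unsqueeze(-1)
    x[mask] = 0.0                                         # -> [[1, 2, 0, 0], [5, 6, 7, 0]]
    zero_start_index_M = torch.count_nonzero(x, dim=-1)   # -> tensor([2, 3])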

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py (+49 -9)
@@ -2301,6 +2301,7 @@ def _kernel_quantize_fp8_row(
     A_scale,
     A_fp8,
     scale_ub,
+    zero_start_index_M,
     B,
     M,
     N,
@@ -2313,10 +2314,14 @@ def _kernel_quantize_fp8_row(
     stride_om,
     stride_on,
     stride_ok,
+    stride_zb,
+    stride_zm,
+    stride_zn,
     TL_FP8_DTYPE: tl.constexpr,
     MAX_FP8: tl.constexpr,
     EPS: tl.constexpr,
     CLAMP_MAX: tl.constexpr,
+    JAGGED: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     USE_INT64: tl.constexpr,
 ) -> None:
@@ -2347,10 +2352,14 @@ def _kernel_quantize_fp8_row(
         stride_om (int): Stride of m dimension of output.
         stride_on (int): Stride of n dimension of output.
         stride_ok (int): Stride of k dimension of output.
+        stride_zb (int): Stride of b dimension of jagged index.
+        stride_zm (int): Stride of m dimension of jagged index.
+        stride_zn (int): Stride of n dimension of jagged index.
         TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
         MAX_FP8 (float): Maxmimum expressible value for FP8.
         EPS (float): Epsilon value for numerical stability.
         CLAMP_MAX (bool): Whethar to apply scale_ub.
+        JAGGED (bool): Whether to use jagged indexing.
         BLOCK_SIZE (int): Block size for reduction.
         USE_INT64 (bool): Whether to use int64 indexing for large inputs.
     """
@@ -2371,11 +2380,25 @@ def _kernel_quantize_fp8_row(
         + (pid % (M * N)) % N * stride_on
     )
 
+    if JAGGED:
+        z_offset_base = (
+            pid // (M * N) * stride_zb
+            + (pid % (M * N)) // N * stride_zm
+            + (pid % (M * N)) % N * stride_zn
+        )
+        row_size = tl.load(zero_start_index_M + z_offset_base)
+    else:
+        row_size = K
+
+    blocks = tl.cdiv(row_size, BLOCK_SIZE)
+
     # Calculate max.
     cur_max = 0.0
-    for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
+    for _k in range(0, blocks):
         a = tl.load(
-            A + a_offset_base + n_offset * stride_ak, mask=n_offset < K, other=0.0
+            A + a_offset_base + n_offset * stride_ak,
+            mask=n_offset < row_size,
+            other=0.0,
         )
         tile_max = tl.max(tl.abs(a))
         cur_max = tl.maximum(tile_max, cur_max)
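
The offset arithmetic above decomposes the flat program id into a (batch, m, n) row coordinate and uses the new strides to locate that row's entry in zero_start_index_M. A plain-Python sketch of the same mapping (illustration only, not kernel code):

    def jagged_row_offset(pid: int, M: int, N: int,
                          stride_zb: int, stride_zm: int, stride_zn: int) -> int:
        # Mirrors z_offset_base in the kernel: find this row's nonzero count.
        b = pid // (M * N)        # batch index
        m = (pid % (M * N)) // N  # row index within the batch
        n = (pid % (M * N)) % N   # inner index within (batch, m)
        return b * stride_zb + m * stride_zm + n * stride_zn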
@@ -2394,7 +2417,9 @@ def _kernel_quantize_fp8_row(
 
     for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
         a = tl.load(
-            A + a_offset_base + n_offset * stride_ak, mask=n_offset < K, other=0.0
+            A + a_offset_base + n_offset * stride_ak,
+            mask=n_offset < row_size,
+            other=0.0,
         )
         a_fp8 = a * a_scale
         # Clamp A to fp8 range to make sure there's no overflow.
@@ -2403,20 +2428,25 @@ def _kernel_quantize_fp8_row(
         a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
         a_fp8.to(TL_FP8_DTYPE)
         tl.store(
-            A_fp8 + a_fp8_offset_base + n_offset * stride_ok, a_fp8, mask=n_offset < K
+            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
+            a_fp8,
+            mask=n_offset < K,
         )
         n_offset += BLOCK_SIZE
 
 
 def triton_quantize_fp8_row(
-    a: Tensor, scale_ub: Optional[Tensor] = None
+    a: Tensor,
+    scale_ub: Optional[Tensor] = None,
+    zero_start_index_M: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor]:
     """
     Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
 
     Args:
         a (Tensor): higher precision input tensor of 4 dimension.
         scale_ub (Tensor): Maximum allowed value for scale.
+        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
 
     Returns:
         torch.Tensor: fp8 scaled tensor.
@@ -2436,6 +2466,7 @@ def triton_quantize_fp8_row(
         a_scale,
         a_fp8,
         scale_ub,
+        zero_start_index_M,
         a.shape[0],
         a.shape[1],
         a.shape[2],
@@ -2448,20 +2479,25 @@ def triton_quantize_fp8_row(
         a_fp8.stride(1),
         a_fp8.stride(2),
         a_fp8.stride(3),
+        zero_start_index_M.stride(0) if zero_start_index_M is not None else None,
+        zero_start_index_M.stride(1) if zero_start_index_M is not None else None,
+        zero_start_index_M.stride(2) if zero_start_index_M is not None else None,
         TL_FP8_DTYPE=tl_dtype,
         MAX_FP8=max_fp8,
         EPS=eps,
         CLAMP_MAX=scale_ub is not None,
+        JAGGED=zero_start_index_M is not None,
         USE_INT64=use_int64,
     )
 
-    return a_fp8, a_scale.view(a.shape[:-1])
+    return a_fp8, a_scale
 
 
 @torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
 def quantize_fp8_row(
     a: Tensor,
     scale_ub: Optional[Tensor] = None,
+    zero_start_index_M: Optional[Tensor] = None,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -2471,6 +2507,7 @@ def quantize_fp8_row(
     Args:
         a (Tensor): Input high precision tensor. Required to have no more than 4 dimension
         scale_ub (Tensor): Maximum allowed value for scale.
+        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
         use_triton (bool): Whether to use triton kernel or pytorch.
         output_device (torch.device): Device to optionally move the scaled tensors to.
 
@@ -2489,8 +2526,11 @@ def quantize_fp8_row(
         a_shape = a.shape
         while a.dim() < 4:
             a = a.unsqueeze(0)
-        a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub)
-        return a_fp8.view(a_shape), a_scale
+        if zero_start_index_M is not None:
+            while zero_start_index_M.dim() < 3:
+                zero_start_index_M = zero_start_index_M.unsqueeze(0)
+        a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
+        return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
     # else use pytorch implementation.
     if not output_device:
         output_device = a.device
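
As shown above, the wrapper pads the input to 4-D and the jagged index to 3-D before dispatching to the Triton path, so the kernel always sees one zero_start_index_M entry per row. A small illustration of the implied shape relation (hypothetical sizes, not from the commit):

    import torch

    a = torch.randn(3, 5, 16)                          # [M, N, K] input
    zero_start_index_M = torch.randint(0, 17, (3, 5))  # one entry per row: [M, N]

    while a.dim() < 4:                       # -> [1, 3, 5, 16]
        a = a.unsqueeze(0)
    while zero_start_index_M.dim() < 3:      # -> [1, 3, 5]
        zero_start_index_M = zero_start_index_M.unsqueeze(0)
    assert a.shape[:-1] == zero_start_index_M.shape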
@@ -2513,7 +2553,7 @@ def quantize_fp8_row(
     a_fp8 = a_fp8.to(device=output_device, dtype=pt_dtype)
     a_scale = a_scale.to(output_device)  # pyre-ignore
     del a
-    return a_fp8, (1 / a_scale).view(a_shape.shape[:-1])  # pyre-ignore
+    return a_fp8, (1 / a_scale).view(a_shape[:-1])  # pyre-ignore
 
 
 @quantize_fp8_row.register_fake
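
For intuition about what the jagged path computes, here is a rough pure-PyTorch reference. It is a sketch only, assuming the kernel scales each row by MAX_FP8 / max(|row[:row_size]|) and returns the reciprocal scale, as the non-Triton branch of quantize_fp8_row suggests; it is not the FBGEMM implementation:

    import torch

    def jagged_quantize_fp8_row_ref(a, zero_start_index_M, eps=1e-12):
        # Row-wise fp8 quantization that ignores the zeroed tail of each row.
        max_fp8 = torch.finfo(torch.float8_e4m3fn).max
        k = a.shape[-1]
        valid = torch.arange(k, device=a.device) < zero_start_index_M.unsqueeze(-1)
        row_max = (a.abs() * valid).amax(dim=-1).clamp_min(eps)  # max over the populated prefix
        scale = max_fp8 / row_max                                # per-row quantization scale
        a_fp8 = (a * scale.unsqueeze(-1)).clamp(-max_fp8, max_fp8).to(torch.float8_e4m3fn)
        return a_fp8, 1.0 / scale                                # fp8 tensor, per-row dequant scale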
