@@ -2301,6 +2301,7 @@ def _kernel_quantize_fp8_row(
2301
2301
A_scale ,
2302
2302
A_fp8 ,
2303
2303
scale_ub ,
2304
+ zero_start_index_M ,
2304
2305
B ,
2305
2306
M ,
2306
2307
N ,
@@ -2313,10 +2314,14 @@ def _kernel_quantize_fp8_row(
2313
2314
stride_om ,
2314
2315
stride_on ,
2315
2316
stride_ok ,
2317
+ stride_zb ,
2318
+ stride_zm ,
2319
+ stride_zn ,
2316
2320
TL_FP8_DTYPE : tl .constexpr ,
2317
2321
MAX_FP8 : tl .constexpr ,
2318
2322
EPS : tl .constexpr ,
2319
2323
CLAMP_MAX : tl .constexpr ,
2324
+ JAGGED : tl .constexpr ,
2320
2325
BLOCK_SIZE : tl .constexpr ,
2321
2326
USE_INT64 : tl .constexpr ,
2322
2327
) -> None :
@@ -2347,10 +2352,14 @@ def _kernel_quantize_fp8_row(
2347
2352
stride_om (int): Stride of m dimension of output.
2348
2353
stride_on (int): Stride of n dimension of output.
2349
2354
stride_ok (int): Stride of k dimension of output.
2355
+ stride_zb (int): Stride of b dimension of jagged index.
2356
+ stride_zm (int): Stride of m dimension of jagged index.
2357
+ stride_zn (int): Stride of n dimension of jagged index.
2350
2358
TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
2351
2359
MAX_FP8 (float): Maximum expressible value for FP8.
2352
2360
EPS (float): Epsilon value for numerical stability.
2353
2361
CLAMP_MAX (bool): Whether to apply scale_ub.
2362
+ JAGGED (bool): Whether to use jagged indexing.
2354
2363
BLOCK_SIZE (int): Block size for reduction.
2355
2364
USE_INT64 (bool): Whether to use int64 indexing for large inputs.
2356
2365
"""
@@ -2371,11 +2380,25 @@ def _kernel_quantize_fp8_row(
2371
2380
+ (pid % (M * N )) % N * stride_on
2372
2381
)
2373
2382
2383
+ if JAGGED :
2384
+ z_offset_base = (
2385
+ pid // (M * N ) * stride_zb
2386
+ + (pid % (M * N )) // N * stride_zm
2387
+ + (pid % (M * N )) % N * stride_zn
2388
+ )
2389
+ row_size = tl .load (zero_start_index_M + z_offset_base )
2390
+ else :
2391
+ row_size = K
2392
+
2393
+ blocks = tl .cdiv (row_size , BLOCK_SIZE )
2394
+
2374
2395
# Calculate max.
2375
2396
cur_max = 0.0
2376
- for _k in range (0 , tl . cdiv ( K , BLOCK_SIZE ) ):
2397
+ for _k in range (0 , blocks ):
2377
2398
a = tl .load (
2378
- A + a_offset_base + n_offset * stride_ak , mask = n_offset < K , other = 0.0
2399
+ A + a_offset_base + n_offset * stride_ak ,
2400
+ mask = n_offset < row_size ,
2401
+ other = 0.0 ,
2379
2402
)
2380
2403
tile_max = tl .max (tl .abs (a ))
2381
2404
cur_max = tl .maximum (tile_max , cur_max )
@@ -2394,7 +2417,9 @@ def _kernel_quantize_fp8_row(
2394
2417
2395
2418
for _k in range (0 , tl .cdiv (K , BLOCK_SIZE )):
2396
2419
a = tl .load (
2397
- A + a_offset_base + n_offset * stride_ak , mask = n_offset < K , other = 0.0
2420
+ A + a_offset_base + n_offset * stride_ak ,
2421
+ mask = n_offset < row_size ,
2422
+ other = 0.0 ,
2398
2423
)
2399
2424
a_fp8 = a * a_scale
2400
2425
# Clamp A to fp8 range to make sure there's no overflow.
@@ -2403,20 +2428,25 @@ def _kernel_quantize_fp8_row(
2403
2428
a_fp8 = tl .clamp (a_fp8 , - MAX_FP8 , MAX_FP8 )
2404
2429
a_fp8 .to (TL_FP8_DTYPE )
2405
2430
tl .store (
2406
- A_fp8 + a_fp8_offset_base + n_offset * stride_ok , a_fp8 , mask = n_offset < K
2431
+ A_fp8 + a_fp8_offset_base + n_offset * stride_ok ,
2432
+ a_fp8 ,
2433
+ mask = n_offset < K ,
2407
2434
)
2408
2435
n_offset += BLOCK_SIZE
2409
2436
2410
2437
2411
2438
def triton_quantize_fp8_row (
2412
- a : Tensor , scale_ub : Optional [Tensor ] = None
2439
+ a : Tensor ,
2440
+ scale_ub : Optional [Tensor ] = None ,
2441
+ zero_start_index_M : Optional [Tensor ] = None ,
2413
2442
) -> Tuple [Tensor , Tensor ]:
2414
2443
"""
2415
2444
Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
2416
2445
2417
2446
Args:
2418
2447
a (Tensor): higher precision input tensor of 4 dimensions.
2419
2448
scale_ub (Tensor): Maximum allowed value for scale.
2449
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
2420
2450
2421
2451
Returns:
2422
2452
torch.Tensor: fp8 scaled tensor.
@@ -2436,6 +2466,7 @@ def triton_quantize_fp8_row(
2436
2466
a_scale ,
2437
2467
a_fp8 ,
2438
2468
scale_ub ,
2469
+ zero_start_index_M ,
2439
2470
a .shape [0 ],
2440
2471
a .shape [1 ],
2441
2472
a .shape [2 ],
@@ -2448,20 +2479,25 @@ def triton_quantize_fp8_row(
2448
2479
a_fp8 .stride (1 ),
2449
2480
a_fp8 .stride (2 ),
2450
2481
a_fp8 .stride (3 ),
2482
+ zero_start_index_M .stride (0 ) if zero_start_index_M is not None else None ,
2483
+ zero_start_index_M .stride (1 ) if zero_start_index_M is not None else None ,
2484
+ zero_start_index_M .stride (2 ) if zero_start_index_M is not None else None ,
2451
2485
TL_FP8_DTYPE = tl_dtype ,
2452
2486
MAX_FP8 = max_fp8 ,
2453
2487
EPS = eps ,
2454
2488
CLAMP_MAX = scale_ub is not None ,
2489
+ JAGGED = zero_start_index_M is not None ,
2455
2490
USE_INT64 = use_int64 ,
2456
2491
)
2457
2492
2458
- return a_fp8 , a_scale . view ( a . shape [: - 1 ])
2493
+ return a_fp8 , a_scale
2459
2494
2460
2495
2461
2496
@torch .library .custom_op ("triton::quantize_fp8_row" , mutates_args = ())
2462
2497
def quantize_fp8_row (
2463
2498
a : Tensor ,
2464
2499
scale_ub : Optional [Tensor ] = None ,
2500
+ zero_start_index_M : Optional [Tensor ] = None ,
2465
2501
use_triton : bool = True ,
2466
2502
output_device : Optional [torch .device ] = None ,
2467
2503
) -> Tuple [torch .Tensor , torch .Tensor ]:
@@ -2471,6 +2507,7 @@ def quantize_fp8_row(
2471
2507
Args:
2472
2508
a (Tensor): Input high precision tensor. Required to have no more than 4 dimensions
2473
2509
scale_ub (Tensor): Maximum allowed value for scale.
2510
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
2474
2511
use_triton (bool): Whether to use triton kernel or pytorch.
2475
2512
output_device (torch.device): Device to optionally move the scaled tensors to.
2476
2513
@@ -2489,8 +2526,11 @@ def quantize_fp8_row(
2489
2526
a_shape = a .shape
2490
2527
while a .dim () < 4 :
2491
2528
a = a .unsqueeze (0 )
2492
- a_fp8 , a_scale = triton_quantize_fp8_row (a , scale_ub )
2493
- return a_fp8 .view (a_shape ), a_scale
2529
+ if zero_start_index_M is not None :
2530
+ while zero_start_index_M .dim () < 3 :
2531
+ zero_start_index_M = zero_start_index_M .unsqueeze (0 )
2532
+ a_fp8 , a_scale = triton_quantize_fp8_row (a , scale_ub , zero_start_index_M )
2533
+ return a_fp8 .view (a_shape ), a_scale .view (a_shape [:- 1 ])
2494
2534
# else use pytorch implementation.
2495
2535
if not output_device :
2496
2536
output_device = a .device
@@ -2513,7 +2553,7 @@ def quantize_fp8_row(
2513
2553
a_fp8 = a_fp8 .to (device = output_device , dtype = pt_dtype )
2514
2554
a_scale = a_scale .to (output_device ) # pyre-ignore
2515
2555
del a
2516
- return a_fp8 , (1 / a_scale ).view (a_shape . shape [:- 1 ]) # pyre-ignore
2556
+ return a_fp8 , (1 / a_scale ).view (a_shape [:- 1 ]) # pyre-ignore
2517
2557
2518
2558
2519
2559
@quantize_fp8_row .register_fake
0 commit comments