add parallel for vae decoding (#134)
Co-authored-by: “Peiyuan Zhang” <[email protected]>
rucnyz and jzhang38 authored Jan 8, 2025
1 parent e0e05f9 commit 4a1f1e3
Showing 6 changed files with 183 additions and 14 deletions.
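In short, this commit adds sequence-parallel VAE decoding: a new enable_vae_sp option is threaded from the sampling script (--vae-sp) through the pipeline and inference entry points into AutoencoderKLCausal3D, whose new parallel_tiled_decode() splits the tiled decode grid across the ranks of the sequence-parallel group, all-gathers the decoded tiles, and blends them as the serial tiled decoder does. Enabling --vae-sp requires --vae-tiling.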
@@ -373,9 +373,7 @@ def decode_latents(self, latents, enable_tiling=True):
latents = 1 / self.vae.config.scaling_factor * latents
if enable_tiling:
self.vae.enable_tiling()
-            image = self.vae.decode(latents, return_dict=False)[0]
-        else:
-            image = self.vae.decode(latents, return_dict=False)[0]
+        image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
if image.ndim == 4:
@@ -605,6 +603,7 @@ def __call__(
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
vae_ver: str = "88-4c-sd",
enable_tiling: bool = False,
+        enable_vae_sp: bool = False,
n_tokens: Optional[int] = None,
embedded_guidance_scale: Optional[float] = None,
**kwargs,
@@ -986,13 +985,11 @@ def __call__(
enabled=vae_autocast_enabled):
if enable_tiling:
self.vae.enable_tiling()
-                image = self.vae.decode(latents,
-                                        return_dict=False,
-                                        generator=generator)[0]
-            else:
-                image = self.vae.decode(latents,
-                                        return_dict=False,
-                                        generator=generator)[0]
+            if enable_vae_sp:
+                self.vae.enable_parallel()
+            image = self.vae.decode(latents,
+                                    return_dict=False,
+                                    generator=generator)[0]

if expand_temporal_dim or image.shape[2] == 1:
image = image.squeeze(2)
1 change: 1 addition & 0 deletions fastvideo/models/hunyuan/inference.py
@@ -523,6 +523,7 @@ def predict(
is_progress_bar=True,
vae_ver=self.args.vae,
enable_tiling=self.args.vae_tiling,
+            enable_vae_sp=self.args.vae_sp,
)[0]
out_dict["samples"] = samples
out_dict["prompts"] = prompt
165 changes: 165 additions & 0 deletions fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py
@@ -17,12 +17,16 @@
#
# ==============================================================================
from dataclasses import dataclass
+from math import prod
from typing import Dict, Optional, Tuple, Union

import torch
+import torch.distributed as dist
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config

+from fastvideo.utils.parallel_states import nccl_info

try:
# This diffusers is modified and packed in the mirror.
from diffusers.loaders import FromOriginalVAEMixin
@@ -119,6 +123,7 @@ def __init__(
self.use_slicing = False
self.use_spatial_tiling = False
self.use_temporal_tiling = False
+        self.use_parallel = False

# only relevant if vae tiling is enabled
self.tile_sample_min_tsize = sample_tsize
@@ -165,6 +170,12 @@ def disable_tiling(self):
self.disable_spatial_tiling()
self.disable_temporal_tiling()

+    def enable_parallel(self):
+        r"""
+        Enable sequence parallelism for the model, allowing the VAE to decode (with tiling) in parallel.
+        """
+        self.use_parallel = True

def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
@@ -319,6 +330,9 @@ def _decode(
) -> Union[DecoderOutput, torch.FloatTensor]:
assert len(z.shape) == 5, "The input tensor should have 5 dimensions."

+        # The parallel path takes precedence over the serial temporal/spatial
+        # tiling paths below and performs tiled decoding internally.
+        if self.use_parallel:
+            return self.parallel_tiled_decode(z, return_dict=return_dict)

if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
return self.temporal_tiled_decode(z, return_dict=return_dict)

@@ -591,6 +605,157 @@ def temporal_tiled_decode(self,

return DecoderOutput(sample=dec)

    def _parallel_data_generator(self, gathered_results,
                                 gathered_dim_metadata):
        """Yield (tile, global_tile_index) pairs from the gathered flat buffer,
        slicing each flattened tile back to its original shape using the
        per-rank shape metadata."""
        global_idx = 0
        for i, per_rank_metadata in enumerate(gathered_dim_metadata):
            _start_shape = 0
            for shape in per_rank_metadata:
                mul_shape = prod(shape)
                yield (gathered_results[i, _start_shape:_start_shape +
                                        mul_shape].reshape(shape), global_idx)
                _start_shape += mul_shape
                global_idx += 1

def parallel_tiled_decode(self,
z: torch.FloatTensor,
return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Parallel version of tiled_decode that distributes both temporal and spatial computation across GPUs
"""
world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group
B, C, T, H, W = z.shape

# Calculate parameters
t_overlap_size = int(self.tile_latent_min_tsize *
(1 - self.tile_overlap_factor))
t_blend_extent = int(self.tile_sample_min_tsize *
self.tile_overlap_factor)
t_limit = self.tile_sample_min_tsize - t_blend_extent

s_overlap_size = int(self.tile_latent_min_size *
(1 - self.tile_overlap_factor))
s_blend_extent = int(self.tile_sample_min_size *
self.tile_overlap_factor)
s_row_limit = self.tile_sample_min_size - s_blend_extent

# Calculate tile dimensions
num_t_tiles = (T + t_overlap_size - 1) // t_overlap_size
num_h_tiles = (H + s_overlap_size - 1) // s_overlap_size
num_w_tiles = (W + s_overlap_size - 1) // s_overlap_size
total_spatial_tiles = num_h_tiles * num_w_tiles
total_tiles = num_t_tiles * total_spatial_tiles

# Calculate tiles per rank and padding
tiles_per_rank = (total_tiles + world_size - 1) // world_size
start_tile_idx = rank * tiles_per_rank
end_tile_idx = min((rank + 1) * tiles_per_rank, total_tiles)
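        # Worked example with assumed values: T=9, H=W=32 with
        # t_overlap_size=4 and s_overlap_size=16 gives num_t_tiles=3 and
        # num_h_tiles=num_w_tiles=2, i.e. 12 tiles overall; with world_size=4,
        # each rank then decodes a contiguous chunk of 3 tiles.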

local_results = []
local_dim_metadata = []
# Process assigned tiles
for local_idx, global_idx in enumerate(
range(start_tile_idx, end_tile_idx)):
# Convert flat index to 3D indices
t_idx = global_idx // total_spatial_tiles
spatial_idx = global_idx % total_spatial_tiles
h_idx = spatial_idx // num_w_tiles
w_idx = spatial_idx % num_w_tiles

# Calculate positions
t_start = t_idx * t_overlap_size
h_start = h_idx * s_overlap_size
w_start = w_idx * s_overlap_size

# Extract and process tile
tile = z[:, :, t_start:t_start + self.tile_latent_min_tsize + 1,
h_start:h_start + self.tile_latent_min_size,
w_start:w_start + self.tile_latent_min_size]
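            # The temporal slice grabs one extra latent frame (the +1 above):
            # for every tile after the first, the leading decoded frame
            # overlaps the previous tile and is dropped again just below.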

# Process tile
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)

if t_start > 0:
decoded = decoded[:, :, 1:, :, :]

# Store metadata
shape = decoded.shape
# Store decoded data (flattened)
decoded_flat = decoded.reshape(-1)
local_results.append(decoded_flat)
local_dim_metadata.append(shape)

results = torch.cat(local_results, dim=0).contiguous()
del local_results
torch.cuda.empty_cache()
# first gather size to pad the results
local_size = torch.tensor([results.size(0)],
device=results.device,
dtype=torch.int64)
all_sizes = [
torch.zeros(1, device=results.device, dtype=torch.int64)
for _ in range(world_size)
]
dist.all_gather(all_sizes, local_size)
max_size = max(size.item() for size in all_sizes)
padded_results = torch.zeros(max_size, device=results.device)
padded_results[:results.size(0)] = results
del results
torch.cuda.empty_cache()
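        # dist.all_gather_into_tensor requires equally sized tensors on every
        # rank, hence the padding to the global max length above; the shape
        # metadata gathered alongside is later used to strip the padding off.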
# Gather all results
gathered_dim_metadata = [None] * world_size
gathered_results = torch.zeros_like(padded_results).repeat(
world_size, *[1] * len(padded_results.shape)
).contiguous(
) # use contiguous to make sure it won't copy data in the following operations
dist.all_gather_into_tensor(gathered_results, padded_results)
dist.all_gather_object(gathered_dim_metadata, local_dim_metadata)
# Process gathered results
data = [[[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)]
for _ in range(num_t_tiles)]
for current_data, global_idx in self._parallel_data_generator(
gathered_results, gathered_dim_metadata):
t_idx = global_idx // total_spatial_tiles
spatial_idx = global_idx % total_spatial_tiles
h_idx = spatial_idx // num_w_tiles
w_idx = spatial_idx % num_w_tiles
data[t_idx][h_idx][w_idx] = current_data
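        # Every decoded tile is now addressable as data[t][h][w]; below, tiles
        # are blended across their overlaps and cropped, mirroring the serial
        # tiled decoder.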
# Merge results
result_slices = []
last_slice_data = None
for i, tem_data in enumerate(data):
slice_data = self._merge_spatial_tiles(tem_data, s_blend_extent,
s_row_limit)
if i > 0:
slice_data = self.blend_t(last_slice_data, slice_data,
t_blend_extent)
result_slices.append(slice_data[:, :, :t_limit, :, :])
else:
result_slices.append(slice_data[:, :, :t_limit + 1, :, :])
last_slice_data = slice_data
dec = torch.cat(result_slices, dim=2)

if not return_dict:
return (dec, )
return DecoderOutput(sample=dec)

def _merge_spatial_tiles(self, spatial_rows, blend_extent, row_limit):
"""Helper function to merge spatial tiles with blending"""
result_rows = []
for i, row in enumerate(spatial_rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(spatial_rows[i - 1][j], tile,
blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
return torch.cat(result_rows, dim=-2)

def forward(
self,
sample: torch.FloatTensor,
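For reference, the decode path added above is driven roughly as follows (a hypothetical caller sketch, not part of this commit; it assumes the sequence-parallel group behind fastvideo.utils.parallel_states.nccl_info has already been initialized and that latents is a 5-D [B, C, T, H, W] tensor):

    vae.enable_tiling()    # required by the sampling script when --vae-sp is set
    vae.enable_parallel()  # what the new enable_vae_sp flag toggles
    image = vae.decode(latents, return_dict=False)[0]
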
6 changes: 6 additions & 0 deletions fastvideo/sample/sample_t2v_hunyuan.py
@@ -202,6 +202,7 @@ def main(args):
default="fp16",
choices=["fp32", "fp16", "bf16"])
parser.add_argument("--vae-tiling", action="store_true", default=True)
parser.add_argument("--vae-sp", action="store_true", default=False)

parser.add_argument("--text-encoder", type=str, default="llm")
parser.add_argument(
@@ -234,4 +235,9 @@
parser.add_argument("--text-len-2", type=int, default=77)

args = parser.parse_args()
+    # process for VAE sequence parallelism
+    if args.vae_sp and not args.vae_tiling:
+        raise ValueError(
+            "Currently, enabling vae_sp requires enabling vae_tiling; please also pass --vae-tiling."
+        )
main(args)
3 changes: 1 addition & 2 deletions scripts/inference/inference_diffusers_hunyuan.sh
@@ -16,5 +16,4 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 12345 \
--seed 1024 \
--output_path outputs_video/hunyuan_quant/nf4/ \
--model_path $MODEL_BASE \
--quantization "nf4" \
--cpu_offload
--quantization "nf4"
5 changes: 3 additions & 2 deletions scripts/inference/inference_hunyuan.sh
@@ -14,6 +14,7 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \
--flow-reverse \
--prompt ./assets/prompt.txt \
--seed 1024 \
-    --output_path outputs_video/hunyuan/cfg6/ \
+    --output_path outputs_video/hunyuan/vae_sp/ \
--model_path $MODEL_BASE \
-    --dit-weight ${MODEL_BASE}/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt
+    --dit-weight ${MODEL_BASE}/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
+    --vae-sp
