Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
BrianChen1129 committed Jan 8, 2025
1 parent 553cee3 commit f53208c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 30 deletions.
71 changes: 42 additions & 29 deletions fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@
#
# ==============================================================================
from dataclasses import dataclass
from math import prod
from typing import Dict, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from math import prod

from fastvideo.utils.parallel_states import (get_sequence_parallel_state,
nccl_info)
import torch.distributed as dist
from fastvideo.utils.parallel_states import nccl_info

try:
# This diffusers is modified and packed in the mirror.
Expand Down Expand Up @@ -606,20 +605,18 @@ def temporal_tiled_decode(self,

return DecoderOutput(sample=dec)

def _parallel_data_generator(self, gathered_results,
                             gathered_dim_metadata):
    """Yield every decoded tile from the gathered per-rank buffers.

    Each rank contributes one flat row of ``gathered_results`` and one list
    of tile shapes in ``gathered_dim_metadata``. The row is consumed from
    the front: for every shape, ``prod(shape)`` elements are sliced off and
    reshaped back into that shape. Any elements left over at the end of a
    row (padding) are never read.

    Args:
        gathered_results: Tensor indexed as ``[rank, flat_offset]``.
            NOTE(review): the caller builds this as
            ``(world_size, max_size)`` — confirm if the layout changes.
        gathered_dim_metadata: One entry per rank; each entry is an
            iterable of tile shapes in the order the rank produced them.

    Yields:
        ``(tile, global_idx)`` tuples, where ``tile`` is the reshaped slice
        and ``global_idx`` is a counter that runs across all ranks in
        rank order, starting at 0.
    """
    global_idx = 0
    for i, per_rank_metadata in enumerate(gathered_dim_metadata):
        # Offset into this rank's flat row; restarts for every rank.
        _start_shape = 0
        for shape in per_rank_metadata:
            mul_shape = prod(shape)
            yield (gathered_results[i, _start_shape:_start_shape +
                                    mul_shape].reshape(shape), global_idx)
            _start_shape += mul_shape
            global_idx += 1


def parallel_tiled_decode(self,
z: torch.FloatTensor,
return_dict: bool = True
Expand All @@ -631,12 +628,16 @@ def parallel_tiled_decode(self,
B, C, T, H, W = z.shape

# Calculate parameters
t_overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
t_blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
t_overlap_size = int(self.tile_latent_min_tsize *
(1 - self.tile_overlap_factor))
t_blend_extent = int(self.tile_sample_min_tsize *
self.tile_overlap_factor)
t_limit = self.tile_sample_min_tsize - t_blend_extent

s_overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
s_blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
s_overlap_size = int(self.tile_latent_min_size *
(1 - self.tile_overlap_factor))
s_blend_extent = int(self.tile_sample_min_size *
self.tile_overlap_factor)
s_row_limit = self.tile_sample_min_size - s_blend_extent

# Calculate tile dimensions
Expand All @@ -654,7 +655,8 @@ def parallel_tiled_decode(self,
local_results = []
local_dim_metadata = []
# Process assigned tiles
for local_idx, global_idx in enumerate(range(start_tile_idx, end_tile_idx)):
for local_idx, global_idx in enumerate(
range(start_tile_idx, end_tile_idx)):
# Convert flat index to 3D indices
t_idx = global_idx // total_spatial_tiles
spatial_idx = global_idx % total_spatial_tiles
Expand All @@ -667,10 +669,9 @@ def parallel_tiled_decode(self,
w_start = w_idx * s_overlap_size

# Extract and process tile
tile = z[:, :,
t_start:t_start + self.tile_latent_min_tsize + 1,
h_start:h_start + self.tile_latent_min_size,
w_start:w_start + self.tile_latent_min_size]
tile = z[:, :, t_start:t_start + self.tile_latent_min_tsize + 1,
h_start:h_start + self.tile_latent_min_size,
w_start:w_start + self.tile_latent_min_size]

# Process tile
tile = self.post_quant_conv(tile)
Expand All @@ -690,8 +691,13 @@ def parallel_tiled_decode(self,
del local_results
torch.cuda.empty_cache()
# first gather size to pad the results
local_size = torch.tensor([results.size(0)], device=results.device, dtype = torch.int64)
all_sizes = [torch.zeros(1, device=results.device, dtype = torch.int64) for _ in range(world_size)]
local_size = torch.tensor([results.size(0)],
device=results.device,
dtype=torch.int64)
all_sizes = [
torch.zeros(1, device=results.device, dtype=torch.int64)
for _ in range(world_size)
]
dist.all_gather(all_sizes, local_size)
max_size = max(size.item() for size in all_sizes)
padded_results = torch.zeros(max_size, device=results.device)
Expand All @@ -700,12 +706,17 @@ def parallel_tiled_decode(self,
torch.cuda.empty_cache()
# Gather all results
gathered_dim_metadata = [None] * world_size
gathered_results = torch.zeros_like(padded_results).repeat(world_size, *[1] * len(padded_results.shape)).contiguous() # use contiguous to make sure it won't copy data in the following operations
gathered_results = torch.zeros_like(padded_results).repeat(
world_size, *[1] * len(padded_results.shape)
).contiguous(
) # use contiguous to make sure it won't copy data in the following operations
dist.all_gather_into_tensor(gathered_results, padded_results)
dist.all_gather_object(gathered_dim_metadata, local_dim_metadata)
# Process gathered results
data = [[[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)] for _ in range(num_t_tiles)]
for current_data, global_idx in self._parallel_data_generator(gathered_results, gathered_dim_metadata):
data = [[[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)]
for _ in range(num_t_tiles)]
for current_data, global_idx in self._parallel_data_generator(
gathered_results, gathered_dim_metadata):
t_idx = global_idx // total_spatial_tiles
spatial_idx = global_idx % total_spatial_tiles
h_idx = spatial_idx // num_w_tiles
Expand All @@ -715,28 +726,30 @@ def parallel_tiled_decode(self,
result_slices = []
last_slice_data = None
for i, tem_data in enumerate(data):
slice_data = self._merge_spatial_tiles(tem_data, s_blend_extent, s_row_limit)
slice_data = self._merge_spatial_tiles(tem_data, s_blend_extent,
s_row_limit)
if i > 0:
slice_data = self.blend_t(last_slice_data, slice_data, t_blend_extent)
slice_data = self.blend_t(last_slice_data, slice_data,
t_blend_extent)
result_slices.append(slice_data[:, :, :t_limit, :, :])
else:
result_slices.append(slice_data[:, :, :t_limit + 1, :, :])
last_slice_data = slice_data
dec = torch.cat(result_slices, dim=2)

if not return_dict:
return (dec,)
return (dec, )
return DecoderOutput(sample=dec)


def _merge_spatial_tiles(self, spatial_rows, blend_extent, row_limit):
"""Helper function to merge spatial tiles with blending"""
result_rows = []
for i, row in enumerate(spatial_rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(spatial_rows[i - 1][j], tile, blend_extent)
tile = self.blend_v(spatial_rows[i - 1][j], tile,
blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
Expand Down
4 changes: 3 additions & 1 deletion fastvideo/sample/sample_t2v_hunyuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,5 +237,7 @@ def main(args):
args = parser.parse_args()
# process for vae sequence parallel
if args.vae_sp and not args.vae_tiling:
raise ValueError("Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True.")
raise ValueError(
"Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True."
)
main(args)

0 comments on commit f53208c

Please sign in to comment.