[feat]: Add tests for FastVideo (#127)
Showing 10 changed files with 322 additions and 13 deletions.
@@ -0,0 +1,30 @@
name: Run Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Check out repository
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12' # or any version you need

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip setuptools wheel
          pip install torch
          pip install packaging ninja
          pip install -e .
          pip install pytest

      - name: Run Pytest
        run: |
          pytest
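For local debugging, the final CI step can be reproduced outside the workflow. A minimal sketch of the programmatic equivalent of the bare `pytest` call above, assuming the editable install from the "Install dependencies" step has already been done:

```python
# Hypothetical local equivalent of the "Run Pytest" step: invoke pytest
# programmatically with no arguments, so tests are discovered from the
# current working directory exactly as the bare `pytest` call does in CI.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main([]))
```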
@@ -20,7 +20,6 @@ wandb/
 *.pt
 cache_dir/
 wandb/
-test*
 sample_video*
 sample_image*
 512*
@@ -0,0 +1,117 @@
import os
import unittest

import torch
from transformers import AutoTokenizer, T5EncoderModel

from fastvideo.models.hunyuan.vae.autoencoder_kl_causal_3d import \
    AutoencoderKLCausal3D


class TestAutoencoderKLCausal3D(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        """
        setUpClass is called once, before any test is run.
        We can set environment variables or load heavy resources here.
        """
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

        # Load tokenizer/model that can be reused across all tests
        cls.tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-t5")
        cls.text_encoder = T5EncoderModel.from_pretrained(
            "hf-internal-testing/tiny-random-t5")

    def setUp(self):
        """
        setUp is called before each test method to prepare fresh state.
        """
        self.batch_size = 1
        self.init_time_len = 9
        self.init_height = 16
        self.init_width = 16
        self.latent_channels = 4
        self.spatial_compression_ratio = 8
        self.time_compression_ratio = 4

        # Model initialization config
        self.init_dict = {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": self.latent_channels,
            "down_block_types": (
                "DownEncoderBlockCausal3D",
                "DownEncoderBlockCausal3D",
                "DownEncoderBlockCausal3D",
                "DownEncoderBlockCausal3D",
            ),
            "up_block_types": (
                "UpDecoderBlockCausal3D",
                "UpDecoderBlockCausal3D",
                "UpDecoderBlockCausal3D",
                "UpDecoderBlockCausal3D",
            ),
            "block_out_channels": (8, 8, 8, 8),
            "layers_per_block": 1,
            "act_fn": "silu",
            "norm_num_groups": 4,
            "scaling_factor": 0.476986,
            "spatial_compression_ratio": self.spatial_compression_ratio,
            "time_compression_ratio": self.time_compression_ratio,
            "mid_block_add_attention": True,
        }

        # Instantiate the model
        self.model = AutoencoderKLCausal3D(**self.init_dict)

        # Create a random input tensor
        self.input_tensor = torch.rand(self.batch_size, 3, self.init_time_len,
                                       self.init_height, self.init_width)

    def test_encode_shape(self):
        """
        Check that the shape of the encoded output matches expectations.
        """
        vae_encoder_output = self.model.encode(self.input_tensor)

        # The distribution from the VAE has a .sample() method,
        # so we verify the shape of that sample.
        sample_shape = vae_encoder_output["latent_dist"].sample().shape

        # We expect shape: [batch_size, latent_channels,
        #                   (init_time_len // time_compression_ratio) + 1,
        #                   init_height // spatial_compression_ratio,
        #                   init_width // spatial_compression_ratio]
        expected_shape = (
            self.batch_size,
            self.latent_channels,
            (self.init_time_len // self.time_compression_ratio) + 1,
            self.init_height // self.spatial_compression_ratio,
            self.init_width // self.spatial_compression_ratio,
        )

        # (Optional) Print them if you like, or just rely on assertions:
        print(f"sample_shape: {sample_shape}")
        print(f"expected_shape: {expected_shape}")

        self.assertEqual(
            sample_shape,
            expected_shape,
            f"Encoded sample shape {sample_shape} does not match {expected_shape}.",
        )


if __name__ == "__main__":
    unittest.main()
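For the configuration hard-coded in setUp, the expected latent shape works out to (1, 4, 3, 2, 2). A small standalone sketch of that arithmetic, using only values taken from the test above:

```python
# Worked example of the expected-shape arithmetic from test_encode_shape,
# using the values hard-coded in setUp.
batch_size = 1
latent_channels = 4
init_time_len, init_height, init_width = 9, 16, 16
time_compression_ratio, spatial_compression_ratio = 4, 8

expected_shape = (
    batch_size,                                     # 1
    latent_channels,                                # 4
    (init_time_len // time_compression_ratio) + 1,  # 9 // 4 + 1 = 3
    init_height // spatial_compression_ratio,       # 16 // 8 = 2
    init_width // spatial_compression_ratio,        # 16 // 8 = 2
)
assert expected_shape == (1, 4, 3, 2, 2)
```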
@@ -0,0 +1,41 @@
import os
import shutil

import pytest
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


@pytest.fixture(scope="module", autouse=True)
def setup_distributed():
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12345"

    dist.init_process_group("nccl")
    yield
    dist.destroy_process_group()

@pytest.mark.skipif(not torch.cuda.is_available(),
                    reason="Requires a CUDA GPU to run NCCL-based FSDP tests")
def test_save_and_remove_checkpoint():
    from fastvideo.models.mochi_hf.modeling_mochi import \
        MochiTransformer3DModel
    from fastvideo.utils.checkpoint import save_checkpoint
    from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs

    transformer = MochiTransformer3DModel(num_layers=0)
    fsdp_kwargs, _ = get_dit_fsdp_kwargs(transformer, "none")
    transformer = FSDP(transformer, **fsdp_kwargs)

    test_folder = "./test_checkpoint"
    save_checkpoint(transformer, 0, test_folder, 0)

    assert os.path.exists(test_folder), "Checkpoint folder was not created."

    shutil.rmtree(test_folder)
    assert not os.path.exists(test_folder), "Checkpoint folder still exists."
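The test writes to a fixed ./test_checkpoint directory and removes it by hand, so a failed assertion leaves the folder behind. A common variant is to write into a temporary directory instead; a sketch under the assumption that save_checkpoint writes its files into the target folder (this is not the repository's actual test code):

```python
# Sketch of the same save-then-verify pattern using a temporary directory,
# so cleanup happens even when an assertion fails. Assumes save_checkpoint
# writes at least one file into the target folder.
import os
import tempfile


def save_and_check(transformer, save_checkpoint):
    with tempfile.TemporaryDirectory() as tmp_dir:
        save_checkpoint(transformer, 0, tmp_dir, 0)
        assert os.listdir(tmp_dir), "Checkpoint folder is empty."
    # TemporaryDirectory removes tmp_dir here, even if the assert fired.
```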
@@ -0,0 +1,122 @@
from functools import partial
from multiprocessing import Manager

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fastvideo.utils.communications import (nccl_info,
                                            prepare_sequence_parallel_data)


def _init_distributed_test_gpu(rank, world_size, backend, port, data, results):
    dist.init_process_group(
        backend=backend,
        init_method=f"tcp://127.0.0.1:{port}",
        world_size=world_size,
        rank=rank,
    )

    device = torch.device(f"cuda:{rank}")

    nccl_info.sp_size = world_size
    nccl_info.rank_within_group = rank
    nccl_info.group_id = 0

    seq_group = dist.new_group(ranks=list(range(world_size)))
    nccl_info.group = seq_group

    hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask = data
    hidden_states = hidden_states[rank].unsqueeze(dim=0).to(device)
    encoder_hidden_states = encoder_hidden_states.to(device)
    attention_mask = attention_mask.to(device)
    encoder_attention_mask = encoder_attention_mask.to(device)
    print(f"Rank {rank} input hidden_states:\n", hidden_states)
    print(f"Rank {rank} input hidden_states shape:\n", hidden_states.shape)
    out_hidden, out_encoder, out_attn_mask, out_encoder_mask = prepare_sequence_parallel_data(
        hidden_states, encoder_hidden_states, attention_mask,
        encoder_attention_mask)
    print(f"Rank {rank} output out_hidden:\n", out_hidden)

    shapes = (
        out_hidden.shape,
        out_encoder.shape,
        out_attn_mask.shape,
        out_encoder_mask.shape,
    )
    shape_tensor = torch.tensor(
        [*shapes[0], *shapes[1], *shapes[2], *shapes[3]],
        dtype=torch.int32,
        device=device)
    shape_list = [torch.zeros_like(shape_tensor) for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor, group=seq_group)
    gathered_shapes = [tuple(s.tolist()) for s in shape_list]
    out_hidden_cpu = out_hidden.to("cpu")

    results[rank] = {
        "shapes": gathered_shapes,
        "out_hidden": out_hidden_cpu,
    }

    dist.barrier()
    dist.destroy_process_group()


@pytest.mark.skipif(not torch.cuda.is_available()
                    or torch.cuda.device_count() < 2,
                    reason="Requires at least 2 GPUs to run NCCL tests")
def test_prepare_sequence_parallel_data_gpu():
    world_size = 2
    backend = "nccl"
    port = 12355  # or use a random free port if collisions occur

    # Create test tensors on CPU; the dimension at index=2 should be
    # divisible by world_size=2 (if applicable).
    hidden_states = torch.randn(2, 1, 2, 1, 1)
    encoder_hidden_states = torch.randn(2, 2)
    attention_mask = torch.randn(2, 2)
    encoder_attention_mask = torch.randn(2, 2)

    print("init hidden states", hidden_states)

    manager = Manager()
    results_dict = manager.dict()

    # Wrap our helper function with partial
    mp_func = partial(_init_distributed_test_gpu,
                      world_size=world_size,
                      backend=backend,
                      port=port,
                      data=(hidden_states, encoder_hidden_states,
                            attention_mask, encoder_attention_mask),
                      results=results_dict)

    # Spawn two GPU processes (rank=0, rank=1)
    mp.spawn(mp_func, nprocs=world_size)

    first_rank_shapes = None

    overall_hidden_out = []

    for rank in sorted(results_dict.keys()):
        rank_data = results_dict[rank]
        rank_shapes = rank_data["shapes"]
        if first_rank_shapes is None:
            first_rank_shapes = rank_shapes
        assert rank_shapes == first_rank_shapes, (
            f"Mismatch in shapes across ranks: {rank_shapes} != {first_rank_shapes}"
        )
        overall_hidden_out.append(rank_data["out_hidden"])

    overall_hidden_out = torch.cat(overall_hidden_out, dim=2)
    print("overall_hidden_out", overall_hidden_out)
    print("overall_hidden_out_shape", overall_hidden_out.shape)
    # overall_hidden_out is already a tensor, so compare it directly;
    # concatenating the per-rank outputs along dim=2 should reconstruct
    # the original hidden_states.
    assert torch.allclose(hidden_states,
                          overall_hidden_out,
                          rtol=1e-7,
                          atol=1e-6)


if __name__ == "__main__":
    test_prepare_sequence_parallel_data_gpu()
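The final torch.allclose encodes a round-trip property: concatenating each rank's out_hidden along the sequence dimension (dim=2) should reproduce the original hidden_states. A single-process sketch of that property, with torch.chunk standing in for the distributed sharding done inside prepare_sequence_parallel_data (this is not the function's actual implementation):

```python
# Single-process sketch of the round-trip property the final assert checks:
# shard along the sequence dimension, then concatenate the shards back
# along dim=2 to recover the original tensor.
import torch

world_size = 2
hidden_states = torch.randn(2, 1, 2, 1, 1)  # same shape as in the test

shards = torch.chunk(hidden_states, world_size, dim=2)
reassembled = torch.cat(shards, dim=2)

assert torch.allclose(hidden_states, reassembled, rtol=1e-7, atol=1e-6)
```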