[feat]: Add tests for FastVideo (#127)
rlsu9 authored Jan 6, 2025
1 parent dd75ee8 commit e0e05f9
Showing 10 changed files with 322 additions and 13 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,30 @@
name: Run Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Check out repository
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12' # or any version you need

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip setuptools wheel
          pip install torch
          pip install packaging ninja
          pip install -e .
          pip install pytest
      - name: Run Pytest
        run: |
          pytest
1 change: 0 additions & 1 deletion .gitignore
@@ -20,7 +20,6 @@ wandb/
 *.pt
 cache_dir/
 wandb/
-test*
 sample_video*
 sample_image*
 512*
7 changes: 5 additions & 2 deletions README.md
@@ -149,9 +149,12 @@ For Image-Video Mixture Fine-tuning, make sure to enable the --group_frame option
 - [ ] fp8 support
 - [ ] faster load model and save model support
 
-## Contributing
+## 🤝 Contributing
 
-We welcome all contributions. Please run bash format.sh before submitting a pull request.
+We welcome all contributions. Please run `bash format.sh` before submitting a pull request.
 
+## 🔧 Testing
+Run `pytest` to verify the data preprocessing, checkpoint saving, and sequence parallel pipelines. We recommend adding corresponding test cases in the `tests` folder to support your contribution.
+
 ## Acknowledgement
 We learned and reused code from the following projects: [PCM](https://github.com/G-U-N/Phased-Consistency-Model), [diffusers](https://github.com/huggingface/diffusers), [OpenSoraPlan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), and [xDiT](https://github.com/xdit-project/xDiT).
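To complement the new Testing section, here is a minimal sketch of what a contributed test case might look like. The file name `tests/test_example.py`, the function name, and the tensor shapes are hypothetical illustrations, not part of this commit:

```python
# tests/test_example.py -- hypothetical contributed test case.
import torch


def test_chunk_roundtrip():
    # Splitting a tensor and concatenating the pieces should be lossless,
    # the same invariant the sequence-parallel test checks across GPUs.
    x = torch.randn(2, 4)
    parts = torch.chunk(x, chunks=2, dim=1)
    assert torch.equal(torch.cat(parts, dim=1), x)
```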
5 changes: 2 additions & 3 deletions fastvideo/distill.py
@@ -439,9 +439,8 @@ def main(args):
     wandb.init(project=project, config=args)
 
     # Train!
-    total_batch_size = (args.train_batch_size * world_size *
-                        args.gradient_accumulation_steps / args.sp_size *
-                        args.train_sp_batch_size)
+    total_batch_size = (world_size * args.gradient_accumulation_steps /
+                        args.sp_size * args.train_sp_batch_size)
     main_print("***** Running training *****")
     main_print(f" Num examples = {len(train_dataset)}")
     main_print(f" Dataloader size = {len(train_dataloader)}")
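The `total_batch_size` fix above (applied identically in `distill_adv.py` and `train.py` below) drops the `args.train_batch_size` factor. A worked sketch of the new arithmetic, with illustrative values that are assumptions rather than numbers from the commit:

```python
# Illustrative settings: 8 GPUs, 4 gradient-accumulation steps,
# sequence-parallel groups of 2 GPUs, 1 sample per group per step.
world_size = 8
gradient_accumulation_steps = 4
sp_size = 2
train_sp_batch_size = 1

# New formula: world_size / sp_size sequence-parallel groups, each
# contributing train_sp_batch_size samples per accumulation step.
total_batch_size = (world_size * gradient_accumulation_steps /
                    sp_size * train_sp_batch_size)
print(total_batch_size)  # 16.0
```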
5 changes: 2 additions & 3 deletions fastvideo/distill_adv.py
@@ -504,9 +504,8 @@ def main(args):
     wandb.init(project=project, config=args)
 
     # Train!
-    total_batch_size = (args.train_batch_size * world_size *
-                        args.gradient_accumulation_steps / args.sp_size *
-                        args.train_sp_batch_size)
+    total_batch_size = (world_size * args.gradient_accumulation_steps /
+                        args.sp_size * args.train_sp_batch_size)
     main_print("***** Running training *****")
     main_print(f" Num examples = {len(train_dataset)}")
     main_print(f" Dataloader size = {len(train_dataloader)}")
5 changes: 2 additions & 3 deletions fastvideo/train.py
@@ -349,9 +349,8 @@ def main(args):
     wandb.init(project=project, config=args)
 
     # Train!
-    total_batch_size = (args.train_batch_size * world_size *
-                        args.gradient_accumulation_steps / args.sp_size *
-                        args.train_sp_batch_size)
+    total_batch_size = (world_size * args.gradient_accumulation_steps /
+                        args.sp_size * args.train_sp_batch_size)
     main_print("***** Running training *****")
     main_print(f" Num examples = {len(train_dataset)}")
    main_print(f" Dataloader size = {len(train_dataloader)}")
2 changes: 1 addition & 1 deletion fastvideo/utils/validation.py
@@ -4,14 +4,14 @@
 
 import numpy as np
 import torch
-import wandb
 from diffusers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import export_to_video
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.video_processor import VideoProcessor
 from einops import rearrange
 from tqdm import tqdm
 
+import wandb
 from fastvideo.distill.solver import PCMFMScheduler
 from fastvideo.models.mochi_hf.pipeline_mochi import (
     linear_quadratic_schedule, retrieve_timesteps)
117 changes: 117 additions & 0 deletions tests/test_data_preprocess.py
@@ -0,0 +1,117 @@
import os
import unittest

import torch
from transformers import AutoTokenizer, T5EncoderModel

from fastvideo.models.hunyuan.vae.autoencoder_kl_causal_3d import \
    AutoencoderKLCausal3D


class TestAutoencoderKLCausal3D(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        """
        setUpClass is called once, before any test is run.
        We can set environment variables or load heavy resources here.
        """
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

        # Load tokenizer/model that can be reused across all tests
        cls.tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-t5")
        cls.text_encoder = T5EncoderModel.from_pretrained(
            "hf-internal-testing/tiny-random-t5")

    def setUp(self):
        """
        setUp is called before each test method to prepare fresh state.
        """
        self.batch_size = 1
        self.init_time_len = 9
        self.init_height = 16
        self.init_width = 16
        self.latent_channels = 4
        self.spatial_compression_ratio = 8
        self.time_compression_ratio = 4

        # Model initialization config
        self.init_dict = {
            "in_channels":
            3,
            "out_channels":
            3,
            "latent_channels":
            self.latent_channels,
            "down_block_types": (
                "DownEncoderBlockCausal3D",
                "DownEncoderBlockCausal3D",
                "DownEncoderBlockCausal3D",
                "DownEncoderBlockCausal3D",
            ),
            "up_block_types": (
                "UpDecoderBlockCausal3D",
                "UpDecoderBlockCausal3D",
                "UpDecoderBlockCausal3D",
                "UpDecoderBlockCausal3D",
            ),
            "block_out_channels": (8, 8, 8, 8),
            "layers_per_block":
            1,
            "act_fn":
            "silu",
            "norm_num_groups":
            4,
            "scaling_factor":
            0.476986,
            "spatial_compression_ratio":
            self.spatial_compression_ratio,
            "time_compression_ratio":
            self.time_compression_ratio,
            "mid_block_add_attention":
            True,
        }

        # Instantiate the model
        self.model = AutoencoderKLCausal3D(**self.init_dict)

        # Create a random input tensor
        self.input_tensor = torch.rand(self.batch_size, 3, self.init_time_len,
                                       self.init_height, self.init_width)

    def test_encode_shape(self):
        """
        Check that the shape of the encoded output matches expectations.
        """
        vae_encoder_output = self.model.encode(self.input_tensor)

        # The distribution from the VAE has a .sample() method,
        # so we verify the shape of that sample.
        sample_shape = vae_encoder_output["latent_dist"].sample().shape

        # We expect shape: [batch_size, latent_channels,
        #                   (init_time_len // time_compression_ratio) + 1,
        #                   init_height // spatial_compression_ratio,
        #                   init_width // spatial_compression_ratio]
        expected_shape = (
            self.batch_size,
            self.latent_channels,
            (self.init_time_len // self.time_compression_ratio) + 1,
            self.init_height // self.spatial_compression_ratio,
            self.init_width // self.spatial_compression_ratio,
        )

        # (Optional) Print them if you like, or just rely on assertions:
        print(f"sample_shape: {sample_shape}")
        print(f"expected_shape: {expected_shape}")

        self.assertEqual(
            sample_shape,
            expected_shape,
            f"Encoded sample shape {sample_shape} does not match {expected_shape}.",
        )


if __name__ == "__main__":
    unittest.main()
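For the exact configuration used in `setUp` (9 frames, 16x16 spatial, time ratio 4, spatial ratio 8), the expected-shape formula works out as follows; this is just the test's own arithmetic spelled out, not additional repo code:

```python
batch_size, latent_channels = 1, 4
init_time_len, init_height, init_width = 9, 16, 16
time_compression_ratio, spatial_compression_ratio = 4, 8

expected_shape = (
    batch_size,
    latent_channels,
    (init_time_len // time_compression_ratio) + 1,  # 9 // 4 + 1 = 3
    init_height // spatial_compression_ratio,       # 16 // 8 = 2
    init_width // spatial_compression_ratio,        # 16 // 8 = 2
)
print(expected_shape)  # (1, 4, 3, 2, 2)
```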
41 changes: 41 additions & 0 deletions tests/test_save_checkpoint.py
@@ -0,0 +1,41 @@
import os
import shutil

import pytest
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


@pytest.fixture(scope="module", autouse=True)
def setup_distributed():
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12345"

    dist.init_process_group("nccl")
    yield
    dist.destroy_process_group()


@pytest.mark.skipif(not torch.cuda.is_available(),
                    reason="Requires a CUDA-capable GPU to run NCCL tests")
def test_save_and_remove_checkpoint():
    from fastvideo.models.mochi_hf.modeling_mochi import \
        MochiTransformer3DModel
    from fastvideo.utils.checkpoint import save_checkpoint
    from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs

    transformer = MochiTransformer3DModel(num_layers=0)
    fsdp_kwargs, _ = get_dit_fsdp_kwargs(transformer, "none")
    transformer = FSDP(transformer, **fsdp_kwargs)

    test_folder = "./test_checkpoint"
    save_checkpoint(transformer, 0, test_folder, 0)

    assert os.path.exists(test_folder), "Checkpoint folder was not created."

    shutil.rmtree(test_folder)
    assert not os.path.exists(test_folder), "Checkpoint folder still exists."
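The fixture above hardcodes the `nccl` backend, so the process group itself needs CUDA. A sketch of a CPU-friendly variant using `gloo` (an assumption about how one might adapt the fixture; whether the FSDP checkpoint path then runs without a GPU depends on the model wrapping):

```python
import os

import pytest
import torch.distributed as dist


@pytest.fixture(scope="module", autouse=True)
def setup_distributed_cpu():
    # Hypothetical single-process setup mirroring the NCCL fixture above.
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12346"

    dist.init_process_group("gloo")  # gloo supports CPU-only hosts
    yield
    dist.destroy_process_group()
```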
122 changes: 122 additions & 0 deletions tests/test_sequence_parallel.py
@@ -0,0 +1,122 @@
from functools import partial
from multiprocessing import Manager

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fastvideo.utils.communications import (nccl_info,
                                            prepare_sequence_parallel_data)


def _init_distributed_test_gpu(rank, world_size, backend, port, data, results):
    dist.init_process_group(
        backend=backend,
        init_method=f"tcp://127.0.0.1:{port}",
        world_size=world_size,
        rank=rank,
    )

    device = torch.device(f"cuda:{rank}")

    nccl_info.sp_size = world_size
    nccl_info.rank_within_group = rank
    nccl_info.group_id = 0

    seq_group = dist.new_group(ranks=list(range(world_size)))
    nccl_info.group = seq_group

    hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask = data
    hidden_states = hidden_states[rank].unsqueeze(dim=0).to(device)
    encoder_hidden_states = encoder_hidden_states.to(device)
    attention_mask = attention_mask.to(device)
    encoder_attention_mask = encoder_attention_mask.to(device)
    print(f"Rank {rank} input hidden_states:\n", hidden_states)
    print(f"Rank {rank} input hidden_states shape:\n", hidden_states.shape)
    out_hidden, out_encoder, out_attn_mask, out_encoder_mask = prepare_sequence_parallel_data(
        hidden_states, encoder_hidden_states, attention_mask,
        encoder_attention_mask)
    print(f"Rank {rank} output out_hidden:\n", out_hidden)

    shapes = (
        out_hidden.shape,
        out_encoder.shape,
        out_attn_mask.shape,
        out_encoder_mask.shape,
    )
    shape_tensor = torch.tensor(
        [*shapes[0], *shapes[1], *shapes[2], *shapes[3]],
        dtype=torch.int32,
        device=device)
    shape_list = [torch.zeros_like(shape_tensor) for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor, group=seq_group)
    gathered_shapes = [tuple(s.tolist()) for s in shape_list]
    out_hidden_cpu = out_hidden.to("cpu")

    results[rank] = {
        "shapes": gathered_shapes,
        "out_hidden": out_hidden_cpu,
    }

    dist.barrier()
    dist.destroy_process_group()


@pytest.mark.skipif(not torch.cuda.is_available()
                    or torch.cuda.device_count() < 2,
                    reason="Requires at least 2 GPUs to run NCCL tests")
def test_prepare_sequence_parallel_data_gpu():
    world_size = 2
    backend = "nccl"
    port = 12355  # or use a random free port if collisions occur

    # Create test tensors on CPU; the dimension at index=2 should be divisible by world_size=2 (if applicable).
    hidden_states = torch.randn(2, 1, 2, 1, 1)
    encoder_hidden_states = torch.randn(2, 2)
    attention_mask = torch.randn(2, 2)
    encoder_attention_mask = torch.randn(2, 2)

    print("init hidden states", hidden_states)

    manager = Manager()
    results_dict = manager.dict()

    # Wrap our helper function with partial
    mp_func = partial(_init_distributed_test_gpu,
                      world_size=world_size,
                      backend=backend,
                      port=port,
                      data=(hidden_states, encoder_hidden_states,
                            attention_mask, encoder_attention_mask),
                      results=results_dict)

    # Spawn two GPU processes (rank=0, rank=1)
    mp.spawn(mp_func, nprocs=world_size)

    first_rank_shapes = None

    overall_hidden_out = []

    for rank in sorted(results_dict.keys()):
        rank_data = results_dict[rank]
        rank_shapes = rank_data["shapes"]
        if first_rank_shapes is None:
            first_rank_shapes = rank_shapes
        assert rank_shapes == first_rank_shapes, (
            f"Mismatch in shapes across ranks: {rank_shapes} != {first_rank_shapes}"
        )
        overall_hidden_out.append(rank_data["out_hidden"])

    overall_hidden_out = torch.cat(overall_hidden_out, dim=2)
    print("overall_hidden_out", overall_hidden_out)
    print("overall_hidden_out_shape", overall_hidden_out.shape)

    assert torch.allclose(hidden_states,
                          overall_hidden_out,
                          rtol=1e-7,
                          atol=1e-6)


if __name__ == "__main__":
    test_prepare_sequence_parallel_data_gpu()
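The final assertion reassembles the per-rank shards by concatenating along dim 2. The same invariant can be demonstrated without any distributed setup; a minimal single-process sketch (illustrative only):

```python
import torch

full = torch.randn(2, 1, 2, 1, 1)  # same shape as the test input

# Shard along dim 2 as if two sequence-parallel ranks each held one slice...
shards = torch.chunk(full, chunks=2, dim=2)

# ...then reassemble; concatenation along the same dim is lossless.
reassembled = torch.cat(shards, dim=2)
assert torch.allclose(full, reassembled, rtol=1e-7, atol=1e-6)
```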
