diff --git a/README.md b/README.md
index 5991d55b..8bf2792b 100644
--- a/README.md
+++ b/README.md
@@ -149,6 +149,10 @@ For Image-Video Mixture Fine-tuning, make sure to enable the --group_frame option
 - [ ] fp8 support
 - [ ] faster load model and save model support
 
+## Contributing
+
+We welcome all contributions. Please run `bash format.sh` before submitting a pull request.
+
 ## Acknowledgement
 
 We learned and reused code from the following projects: [PCM](https://github.com/G-U-N/Phased-Consistency-Model), [diffusers](https://github.com/huggingface/diffusers), [OpenSoraPlan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), and [xDiT](https://github.com/xdit-project/xDiT).
diff --git a/fastvideo/distill.py b/fastvideo/distill.py
index 89f02e13..34011ae4 100644
--- a/fastvideo/distill.py
+++ b/fastvideo/distill.py
@@ -631,6 +631,8 @@ def get_num_phases(multi_phased_distill_schedule, step):
     # dataset & dataloader
     parser.add_argument("--data_json_path", type=str, required=True)
+    parser.add_argument("--num_height", type=int, default=480)
+    parser.add_argument("--num_width", type=int, default=848)
     parser.add_argument("--num_frames", type=int, default=163)
     parser.add_argument(
         "--dataloader_num_workers",
diff --git a/fastvideo/distill_adv.py b/fastvideo/distill_adv.py
index 51a20489..5d3cd942 100644
--- a/fastvideo/distill_adv.py
+++ b/fastvideo/distill_adv.py
@@ -425,7 +425,7 @@ def main(args):
         params_to_optimize,
         lr=args.learning_rate,
         betas=(0.9, 0.999),
-        weight_decay=1e-3,
+        weight_decay=args.weight_decay,
         eps=1e-8,
     )
 
@@ -433,7 +433,7 @@ def main(args):
         discriminator.parameters(),
         lr=args.discriminator_learning_rate,
         betas=(0, 0.999),
-        weight_decay=1e-3,
+        weight_decay=args.weight_decay,
         eps=1e-8,
     )
 
@@ -634,8 +634,7 @@ def get_num_phases(multi_phased_distill_schedule, step):
             #     args.output_dir,
             #     step,
             # )
-            save_checkpoint(transformer, rank, args.output_dir,
-                            args.max_train_steps)
+            save_checkpoint(transformer, rank, args.output_dir, step)
             main_print(f"--> checkpoint saved at step {step}")
             dist.barrier()
         if args.log_validation and step % args.validation_steps == 0:
@@ -673,6 +672,8 @@ def get_num_phases(multi_phased_distill_schedule, step):
                         help="The type of model to train.")
     # dataset & dataloader
     parser.add_argument("--data_json_path", type=str, required=True)
+    parser.add_argument("--num_height", type=int, default=480)
+    parser.add_argument("--num_width", type=int, default=848)
     parser.add_argument("--num_frames", type=int, default=163)
     parser.add_argument(
         "--dataloader_num_workers",
@@ -922,6 +923,16 @@ def get_num_phases(multi_phased_distill_schedule, step):
         default=2,
         help="The stride of the discriminator head.",
     )
+    parser.add_argument(
+        "--linear_range",
+        type=float,
+        default=0.5,
+        help="Range for linear quadratic scheduler.",
+    )
+    parser.add_argument("--weight_decay",
+                        type=float,
+                        default=0.001,
+                        help="Weight decay to apply.")
     parser.add_argument(
         "--linear_quadratic_threshold",
         type=float,
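Note: the new `--linear_range` flag works together with `--linear_quadratic_threshold` to parameterize the linear-quadratic timestep schedule (see the `linear_quadratic_schedule` import in `fastvideo/utils/validation.py` below). A minimal sketch of such a schedule, assuming the common formulation (sigma ramps linearly up to the threshold over the first `linear_range` fraction of steps, then quadratically to 1); this is illustrative only, not the repo's implementation:

```python
# Illustrative sketch only, NOT fastvideo's linear_quadratic_schedule:
# sigma rises linearly to `threshold` over the first `linear_range` fraction
# of steps, then quadratically from `threshold` up to 1.
def linear_quadratic_sigmas(num_steps: int, threshold: float,
                            linear_range: float) -> list:
    linear_steps = max(1, int(num_steps * linear_range))
    sigmas = [threshold * i / linear_steps for i in range(linear_steps)]
    quadratic_steps = num_steps - linear_steps
    for i in range(1, quadratic_steps + 1):
        t = i / quadratic_steps
        sigmas.append(threshold + (1.0 - threshold) * t * t)  # continuous at the seam
    return sigmas

# Hypothetical values: 0.5 matches the new --linear_range default; the
# default of --linear_quadratic_threshold is not visible in this diff.
print(linear_quadratic_sigmas(num_steps=50, threshold=0.025, linear_range=0.5)[:4])
```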
diff --git a/fastvideo/train.py b/fastvideo/train.py
index ecb7cb92..33b96df0 100644
--- a/fastvideo/train.py
+++ b/fastvideo/train.py
@@ -472,6 +472,8 @@ def main(args):
                         help="The type of model to train.")
     # dataset & dataloader
     parser.add_argument("--data_json_path", type=str, required=True)
+    parser.add_argument("--num_height", type=int, default=480)
+    parser.add_argument("--num_width", type=int, default=848)
     parser.add_argument("--num_frames", type=int, default=163)
     parser.add_argument(
         "--dataloader_num_workers",
diff --git a/fastvideo/utils/checkpoint.py b/fastvideo/utils/checkpoint.py
index b84fd346..02fa2e58 100644
--- a/fastvideo/utils/checkpoint.py
+++ b/fastvideo/utils/checkpoint.py
@@ -19,12 +19,12 @@ from fastvideo.utils.logging_ import main_print
 
 
-def save_checkpoint(model,
-                    optimizer,
-                    rank,
-                    output_dir,
-                    step,
-                    discriminator=False):
+def save_checkpoint_optimizer(model,
+                              optimizer,
+                              rank,
+                              output_dir,
+                              step,
+                              discriminator=False):
     with FSDP.state_dict_type(
             model,
             StateDictType.FULL_STATE_DICT,
@@ -60,6 +60,32 @@ def save_checkpoint(model,
         torch.save(optim_state, optimizer_path)
 
 
+def save_checkpoint(transformer, rank, output_dir, step):
+    main_print(f"--> saving checkpoint at step {step}")
+    with FSDP.state_dict_type(
+            transformer,
+            StateDictType.FULL_STATE_DICT,
+            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+    ):
+        cpu_state = transformer.state_dict()
+    # todo move to get_state_dict
+    if rank <= 0:
+        save_dir = os.path.join(output_dir, f"checkpoint-{step}")
+        os.makedirs(save_dir, exist_ok=True)
+        # save using safetensors
+        weight_path = os.path.join(save_dir,
+                                   "diffusion_pytorch_model.safetensors")
+        save_file(cpu_state, weight_path)
+        config_dict = dict(transformer.config)
+        if "dtype" in config_dict:
+            del config_dict["dtype"]  # TODO
+        config_path = os.path.join(save_dir, "config.json")
+        # save dict as json
+        with open(config_path, "w") as f:
+            json.dump(config_dict, f, indent=4)
+    main_print(f"--> checkpoint saved at step {step}")
+
+
 def save_checkpoint_generator_discriminator(
     model,
     optimizer,
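Note: the renamed `save_checkpoint_optimizer` keeps the old weights-plus-optimizer format, while the new `save_checkpoint` writes only `diffusion_pytorch_model.safetensors` and a `config.json` under `checkpoint-{step}/`. A minimal sketch of reading such a checkpoint back (the directory and model class are placeholders, not taken from this diff):

```python
# Minimal sketch: load a checkpoint written by the new save_checkpoint().
# `ckpt_dir` is a placeholder path; the class to instantiate from config.json
# is whichever transformer produced the state dict (hypothetical here).
import json
import os

from safetensors.torch import load_file

ckpt_dir = "outputs/checkpoint-320"
state_dict = load_file(os.path.join(ckpt_dir, "diffusion_pytorch_model.safetensors"))
with open(os.path.join(ckpt_dir, "config.json")) as f:
    config = json.load(f)
# model = SomeTransformer(**config)   # hypothetical class
# model.load_state_dict(state_dict)
```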
diff --git a/fastvideo/utils/load.py b/fastvideo/utils/load.py
index 8637feb1..8931cb76 100644
--- a/fastvideo/utils/load.py
+++ b/fastvideo/utils/load.py
@@ -274,6 +274,8 @@ def load_transformer(
         )
         transformer = load_hunyuan_state_dict(transformer,
                                               dit_model_name_or_path)
+        if master_weight_type == torch.bfloat16:
+            transformer = transformer.bfloat16()
     else:
         raise ValueError(f"Unsupported model type: {model_type}")
     return transformer
diff --git a/fastvideo/utils/validation.py b/fastvideo/utils/validation.py
index 09ec70a4..06f94dbf 100644
--- a/fastvideo/utils/validation.py
+++ b/fastvideo/utils/validation.py
@@ -4,7 +4,6 @@
 
 import numpy as np
 import torch
-import wandb
 from diffusers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import export_to_video
 from diffusers.utils.torch_utils import randn_tensor
@@ -12,6 +11,7 @@
 from einops import rearrange
 from tqdm import tqdm
 
+import wandb
 from fastvideo.distill.solver import PCMFMScheduler
 from fastvideo.models.mochi_hf.pipeline_mochi import (
     linear_quadratic_schedule, retrieve_timesteps)
@@ -294,8 +294,8 @@ def log_validation(
         scheduler_type=scheduler_type,
         num_frames=args.num_frames,
         # Peiyuan TODO: remove hardcode
-        height=480,
-        width=848,
+        height=args.num_height,
+        width=args.num_width,
         num_inference_steps=validation_sampling_step,
         guidance_scale=validation_guidance_scale,
         generator=generator,
diff --git a/scripts/distill/distill_hunyuan.sh b/scripts/distill/distill_hunyuan.sh
index 1fff28df..4dd63af0 100644
--- a/scripts/distill/distill_hunyuan.sh
+++ b/scripts/distill/distill_hunyuan.sh
@@ -2,7 +2,52 @@ export WANDB_BASE_URL="https://api.wandb.ai"
 export WANDB_MODE=online
 
 DATA_DIR=./data
+IP=[MASTER NODE IP]
 
+torchrun --nnodes 4 --nproc_per_node 8\
+    --node_rank=0 \
+    --rdzv_id=456 \
+    --rdzv_backend=c10d \
+    --rdzv_endpoint=$IP:29500 \
+    fastvideo/distill.py\
+    --seed 42\
+    --pretrained_model_name_or_path $DATA_DIR/hunyuan\
+    --dit_model_name_or_path $DATA_DIR/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt\
+    --model_type "hunyuan" \
+    --cache_dir "$DATA_DIR/.cache"\
+    --data_json_path "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/videos2caption.json"\
+    --validation_prompt_dir "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/validation"\
+    --gradient_checkpointing\
+    --train_batch_size=1\
+    --num_latent_t 32 \
+    --sp_size 2 \
+    --train_sp_batch_size 1\
+    --dataloader_num_workers 4\
+    --gradient_accumulation_steps=1\
+    --max_train_steps=320\
+    --learning_rate=1e-6\
+    --mixed_precision="bf16"\
+    --checkpointing_steps=64\
+    --validation_steps 64\
+    --validation_sampling_steps "2,4,8" \
+    --checkpoints_total_limit 3\
+    --allow_tf32\
+    --ema_start_step 0\
+    --cfg 0.0\
+    --log_validation\
+    --output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_16_HD"\
+    --tracker_project_name Hunyuan_Distill \
+    --num_height 720 \
+    --num_width 1280 \
+    --num_frames 125 \
+    --shift 17 \
+    --validation_guidance_scale "1.0" \
+    --num_euler_timesteps 50 \
+    --multi_phased_distill_schedule "4000-1" \
+    --not_apply_cfg_solver
+
+
+# If you do not have 32 GPUs and need to fit in memory, you can: 1. increase --sp_size; 2. reduce --num_latent_t.
 torchrun --nnodes 1 --nproc_per_node 8\
     fastvideo/distill.py\
     --seed 42\
@@ -10,16 +55,16 @@ torchrun --nnodes 1 --nproc_per_node 8\
     --dit_model_name_or_path $DATA_DIR/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt\
     --model_type "hunyuan" \
     --cache_dir "$DATA_DIR/.cache"\
-    --data_json_path "$DATA_DIR/Hunyuan-30K-Distill-Data/videos2caption.json"\
-    --validation_prompt_dir "$DATA_DIR/Hunyuan-Distill-Data/validation"\
+    --data_json_path "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/videos2caption.json"\
+    --validation_prompt_dir "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/validation"\
     --gradient_checkpointing\
     --train_batch_size=1\
-    --num_latent_t 24\
-    --sp_size 1\
+    --num_latent_t 32 \
+    --sp_size 2 \
     --train_sp_batch_size 1\
     --dataloader_num_workers 4\
     --gradient_accumulation_steps=1\
-    --max_train_steps=2000\
+    --max_train_steps=320\
     --learning_rate=1e-6\
     --mixed_precision="bf16"\
     --checkpointing_steps=64\
@@ -30,9 +75,11 @@ torchrun --nnodes 1 --nproc_per_node 8\
     --ema_start_step 0\
     --cfg 0.0\
     --log_validation\
-    --output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_32"\
+    --output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_16_HD"\
     --tracker_project_name Hunyuan_Distill \
-    --num_frames 93 \
+    --num_height 720 \
+    --num_width 1280 \
+    --num_frames 125 \
     --shift 17 \
     --validation_guidance_scale "1.0" \
     --num_euler_timesteps 50 \
diff --git a/scripts/finetune/finetune_hunyuan.sh b/scripts/finetune/finetune_hunyuan.sh
index 5509dfa7..12e8e43a 100644
--- a/scripts/finetune/finetune_hunyuan.sh
+++ b/scripts/finetune/finetune_hunyuan.sh
@@ -32,5 +32,7 @@ torchrun --nnodes 1 --nproc_per_node 8 \
     --output_dir=data/outputs/HSH-Taylor-Finetune-Hunyuan \
     --tracker_project_name HSH-Taylor-Finetune-Hunyuan \
     --num_frames 93 \
+    --num_height 720 \
+    --num_width 1280 \
     --validation_guidance_scale "1.0" \
     --group_frame
\ No newline at end of file
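Note: taken together, the resolution changes thread `--num_height` and `--num_width` from the CLI through to `log_validation`. A standalone sketch of the pattern (illustrative only, not repo code): the defaults reproduce the previously hardcoded 480x848, while the updated scripts pass 720x1280.

```python
# Standalone illustration of the new resolution flags (not repo code).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_height", type=int, default=480)  # old hardcoded height
parser.add_argument("--num_width", type=int, default=848)   # old hardcoded width
args = parser.parse_args(["--num_height", "720", "--num_width", "1280"])

# log_validation() now forwards these instead of constants:
print(dict(height=args.num_height, width=args.num_width))  # {'height': 720, 'width': 1280}
```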