[Fix] Save CK, Dataset bug fix (hao-ai-lab#125)
jzhang38 authored Jan 1, 2025
1 parent cff5621 commit 0c7e0bd
Showing 9 changed files with 116 additions and 20 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -149,6 +149,10 @@ For Image-Video Mixture Fine-tuning, make sure to enable the --group_frame option
- [ ] fp8 support
- [ ] faster load model and save model support

## Contributing

We welcome all contributions. Please run `bash format.sh` before submitting a pull request.

## Acknowledgement
We learned from and reused code from the following projects: [PCM](https://github.com/G-U-N/Phased-Consistency-Model), [diffusers](https://github.com/huggingface/diffusers), [OpenSoraPlan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), and [xDiT](https://github.com/xdit-project/xDiT).

2 changes: 2 additions & 0 deletions fastvideo/distill.py
@@ -631,6 +631,8 @@ def get_num_phases(multi_phased_distill_schedule, step):

# dataset & dataloader
parser.add_argument("--data_json_path", type=str, required=True)
parser.add_argument("--num_height", type=int, default=480)
parser.add_argument("--num_width", type=int, default=848)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
19 changes: 15 additions & 4 deletions fastvideo/distill_adv.py
@@ -425,15 +425,15 @@ def main(args):
params_to_optimize,
lr=args.learning_rate,
betas=(0.9, 0.999),
weight_decay=1e-3,
weight_decay=args.weight_decay,
eps=1e-8,
)

discriminator_optimizer = torch.optim.AdamW(
discriminator.parameters(),
lr=args.discriminator_learning_rate,
betas=(0, 0.999),
weight_decay=1e-3,
weight_decay=args.weight_decay,
eps=1e-8,
)

@@ -634,8 +634,7 @@ def get_num_phases(multi_phased_distill_schedule, step):
# args.output_dir,
# step,
# )
save_checkpoint(transformer, rank, args.output_dir,
args.max_train_steps)
save_checkpoint(transformer, rank, args.output_dir, step)
main_print(f"--> checkpoint saved at step {step}")
dist.barrier()
if args.log_validation and step % args.validation_steps == 0:
@@ -673,6 +672,8 @@ def get_num_phases(multi_phased_distill_schedule, step):
help="The type of model to train.")
# dataset & dataloader
parser.add_argument("--data_json_path", type=str, required=True)
parser.add_argument("--num_height", type=int, default=480)
parser.add_argument("--num_width", type=int, default=848)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
@@ -922,6 +923,16 @@ def get_num_phases(multi_phased_distill_schedule, step):
default=2,
help="The stride of the discriminator head.",
)
parser.add_argument(
"--linear_range",
type=float,
default=0.5,
help="Range for linear quadratic scheduler.",
)
parser.add_argument("--weight_decay",
type=float,
default=0.001,
help="Weight decay to apply.")
parser.add_argument(
"--linear_quadratic_threshold",
type=float,
2 changes: 2 additions & 0 deletions fastvideo/train.py
@@ -472,6 +472,8 @@ def main(args):
help="The type of model to train.")
# dataset & dataloader
parser.add_argument("--data_json_path", type=str, required=True)
parser.add_argument("--num_height", type=int, default=480)
parser.add_argument("--num_width", type=int, default=848)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
Expand Down
38 changes: 32 additions & 6 deletions fastvideo/utils/checkpoint.py
@@ -19,12 +19,12 @@
from fastvideo.utils.logging_ import main_print


def save_checkpoint(model,
optimizer,
rank,
output_dir,
step,
discriminator=False):
def save_checkpoint_optimizer(model,
optimizer,
rank,
output_dir,
step,
discriminator=False):
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
@@ -60,6 +60,32 @@ def save_checkpoint(model,
torch.save(optim_state, optimizer_path)


def save_checkpoint(transformer, rank, output_dir, step):
main_print(f"--> saving checkpoint at step {step}")
with FSDP.state_dict_type(
transformer,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
cpu_state = transformer.state_dict()
# todo move to get_state_dict
if rank <= 0:
save_dir = os.path.join(output_dir, f"checkpoint-{step}")
os.makedirs(save_dir, exist_ok=True)
# save using safetensors
weight_path = os.path.join(save_dir,
"diffusion_pytorch_model.safetensors")
save_file(cpu_state, weight_path)
config_dict = dict(transformer.config)
if "dtype" in config_dict:
del config_dict["dtype"] # TODO
config_path = os.path.join(save_dir, "config.json")
# save dict as json
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
main_print(f"--> checkpoint saved at step {step}")


def save_checkpoint_generator_discriminator(
model,
optimizer,
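
Below is a minimal sketch, not part of this commit, of how a checkpoint written by the new save_checkpoint could be read back. The directory layout (checkpoint-<step>/diffusion_pytorch_model.safetensors plus config.json) follows the code above; the load_saved_checkpoint helper and the model_cls(**config) construction are illustrative assumptions.

import json
import os

from safetensors.torch import load_file


def load_saved_checkpoint(output_dir, step, model_cls):
    # Mirrors the layout written by save_checkpoint above.
    save_dir = os.path.join(output_dir, f"checkpoint-{step}")
    # config.json was dumped from transformer.config (with "dtype" removed).
    with open(os.path.join(save_dir, "config.json")) as f:
        config = json.load(f)
    # Assumption: the transformer class can be constructed from its config dict.
    model = model_cls(**config)
    state_dict = load_file(
        os.path.join(save_dir, "diffusion_pytorch_model.safetensors"))
    model.load_state_dict(state_dict)
    return model
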
2 changes: 2 additions & 0 deletions fastvideo/utils/load.py
@@ -274,6 +274,8 @@ def load_transformer(
)
transformer = load_hunyuan_state_dict(transformer,
dit_model_name_or_path)
if master_weight_type == torch.bfloat16:
transformer = transformer.bfloat16()
else:
raise ValueError(f"Unsupported model type: {model_type}")
return transformer
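
As a reference, here is a minimal sketch, an assumption rather than code from this commit, of the dtype-alignment pattern the load.py change applies, presumably so the loaded Hunyuan weights match the master weight dtype used for training.

import torch


def align_to_master_dtype(module: torch.nn.Module,
                          master_weight_type: torch.dtype) -> torch.nn.Module:
    # nn.Module.bfloat16() casts floating-point parameters and buffers to bfloat16.
    if master_weight_type == torch.bfloat16:
        module = module.bfloat16()
    return module
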
6 changes: 3 additions & 3 deletions fastvideo/utils/validation.py
@@ -4,14 +4,14 @@

import numpy as np
import torch
import wandb
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from einops import rearrange
from tqdm import tqdm

import wandb
from fastvideo.distill.solver import PCMFMScheduler
from fastvideo.models.mochi_hf.pipeline_mochi import (
linear_quadratic_schedule, retrieve_timesteps)
@@ -294,8 +294,8 @@ def log_validation(
scheduler_type=scheduler_type,
num_frames=args.num_frames,
# Peiyuan TODO: remove hardcode
height=480,
width=848,
height=args.num_height,
width=args.num_width,
num_inference_steps=validation_sampling_step,
guidance_scale=validation_guidance_scale,
generator=generator,
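
Taken together, the new --num_height/--num_width arguments and this validation.py change replace the hardcoded 480x848 validation resolution. A small sketch of the wiring, with the parser lines mirroring this diff and the literal values passed to parse_args used only for illustration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_height", type=int, default=480)
parser.add_argument("--num_width", type=int, default=848)
args = parser.parse_args(["--num_height", "720", "--num_width", "1280"])
# Validation now samples at the requested resolution instead of a fixed 480x848:
# samples = pipeline(..., height=args.num_height, width=args.num_width, ...)
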
61 changes: 54 additions & 7 deletions scripts/distill/distill_hunyuan.sh
@@ -2,24 +2,69 @@ export WANDB_BASE_URL="https://api.wandb.ai"
export WANDB_MODE=online

DATA_DIR=./data
IP=[MASTER NODE IP]

torchrun --nnodes 4 --nproc_per_node 8\
--node_rank=0 \
--rdzv_id=456 \
--rdzv_backend=c10d \
--rdzv_endpoint=$IP:29500 \
fastvideo/distill.py\
--seed 42\
--pretrained_model_name_or_path $DATA_DIR/hunyuan\
--dit_model_name_or_path $DATA_DIR/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt\
--model_type "hunyuan" \
--cache_dir "$DATA_DIR/.cache"\
--data_json_path "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/videos2caption.json"\
--validation_prompt_dir "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/validation"\
--gradient_checkpointing\
--train_batch_size=1\
--num_latent_t 32 \
--sp_size 2 \
--train_sp_batch_size 1\
--dataloader_num_workers 4\
--gradient_accumulation_steps=1\
--max_train_steps=320\
--learning_rate=1e-6\
--mixed_precision="bf16"\
--checkpointing_steps=64\
--validation_steps 64\
--validation_sampling_steps "2,4,8" \
--checkpoints_total_limit 3\
--allow_tf32\
--ema_start_step 0\
--cfg 0.0\
--log_validation\
--output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_16_HD"\
--tracker_project_name Hunyuan_Distill \
--num_height 720 \
--num_width 1280 \
--num_frames 125 \
--shift 17 \
--validation_guidance_scale "1.0" \
--num_euler_timesteps 50 \
--multi_phased_distill_schedule "4000-1" \
--not_apply_cfg_solver


# If you do not have 32 GPUs and need to fit in memory, you can: 1. increase sp_size; 2. reduce num_latent_t
torchrun --nnodes 1 --nproc_per_node 8\
fastvideo/distill.py\
--seed 42\
--pretrained_model_name_or_path $DATA_DIR/hunyuan\
--dit_model_name_or_path $DATA_DIR/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt\
--model_type "hunyuan" \
--cache_dir "$DATA_DIR/.cache"\
--data_json_path "$DATA_DIR/Hunyuan-30K-Distill-Data/videos2caption.json"\
--validation_prompt_dir "$DATA_DIR/Hunyuan-Distill-Data/validation"\
--data_json_path "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/videos2caption.json"\
--validation_prompt_dir "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/validation"\
--gradient_checkpointing\
--train_batch_size=1\
--num_latent_t 24\
--sp_size 1\
--num_latent_t 32 \
--sp_size 2 \
--train_sp_batch_size 1\
--dataloader_num_workers 4\
--gradient_accumulation_steps=1\
--max_train_steps=2000\
--max_train_steps=320\
--learning_rate=1e-6\
--mixed_precision="bf16"\
--checkpointing_steps=64\
@@ -30,9 +75,11 @@ torchrun --nnodes 1 --nproc_per_node 8\
--ema_start_step 0\
--cfg 0.0\
--log_validation\
--output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_32"\
--output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_16_HD"\
--tracker_project_name Hunyuan_Distill \
--num_frames 93 \
--num_height 720 \
--num_width 1280 \
--num_frames 125 \
--shift 17 \
--validation_guidance_scale "1.0" \
--num_euler_timesteps 50 \
2 changes: 2 additions & 0 deletions scripts/finetune/finetune_hunyuan.sh
@@ -32,5 +32,7 @@ torchrun --nnodes 1 --nproc_per_node 8 \
--output_dir=data/outputs/HSH-Taylor-Finetune-Hunyuan \
--tracker_project_name HSH-Taylor-Finetune-Hunyuan \
--num_frames 93 \
--num_height 720 \
--num_width 1280 \
--validation_guidance_scale "1.0" \
--group_frame
