Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Save CK, Dataset bug fix #125

Merged
merged 3 commits into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ For Image-Video Mixture Fine-tuning, make sure to enable the --group_frame optio
- [ ] fp8 support
- [ ] faster load model and save model support

## Contributing

We welcome all contributions. Please run `bash format.sh` before submitting a pull request.

## Acknowledgement
We learned and reused code from the following projects: [PCM](https://github.com/G-U-N/Phased-Consistency-Model), [diffusers](https://github.com/huggingface/diffusers), [OpenSoraPlan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), and [xDiT](https://github.com/xdit-project/xDiT).

Expand Down
2 changes: 2 additions & 0 deletions fastvideo/distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,8 @@ def get_num_phases(multi_phased_distill_schedule, step):

# dataset & dataloader
parser.add_argument("--data_json_path", type=str, required=True)
parser.add_argument("--num_height", type=int, default=480)
parser.add_argument("--num_width", type=int, default=848)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
Expand Down
19 changes: 15 additions & 4 deletions fastvideo/distill_adv.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,15 @@ def main(args):
params_to_optimize,
lr=args.learning_rate,
betas=(0.9, 0.999),
weight_decay=1e-3,
weight_decay=args.weight_decay,
eps=1e-8,
)

discriminator_optimizer = torch.optim.AdamW(
discriminator.parameters(),
lr=args.discriminator_learning_rate,
betas=(0, 0.999),
weight_decay=1e-3,
weight_decay=args.weight_decay,
eps=1e-8,
)

Expand Down Expand Up @@ -634,8 +634,7 @@ def get_num_phases(multi_phased_distill_schedule, step):
# args.output_dir,
# step,
# )
save_checkpoint(transformer, rank, args.output_dir,
args.max_train_steps)
save_checkpoint(transformer, rank, args.output_dir, step)
main_print(f"--> checkpoint saved at step {step}")
dist.barrier()
if args.log_validation and step % args.validation_steps == 0:
Expand Down Expand Up @@ -673,6 +672,8 @@ def get_num_phases(multi_phased_distill_schedule, step):
help="The type of model to train.")
# dataset & dataloader
parser.add_argument("--data_json_path", type=str, required=True)
parser.add_argument("--num_height", type=int, default=480)
parser.add_argument("--num_width", type=int, default=848)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
Expand Down Expand Up @@ -922,6 +923,16 @@ def get_num_phases(multi_phased_distill_schedule, step):
default=2,
help="The stride of the discriminator head.",
)
parser.add_argument(
"--linear_range",
type=float,
default=0.5,
help="Range for linear quadratic scheduler.",
)
parser.add_argument("--weight_decay",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code format?

type=float,
default=0.001,
help="Weight decay to apply.")
parser.add_argument(
"--linear_quadratic_threshold",
type=float,
Expand Down
2 changes: 2 additions & 0 deletions fastvideo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,8 @@ def main(args):
help="The type of model to train.")
# dataset & dataloader
parser.add_argument("--data_json_path", type=str, required=True)
parser.add_argument("--num_height", type=int, default=480)
parser.add_argument("--num_width", type=int, default=848)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
Expand Down
38 changes: 32 additions & 6 deletions fastvideo/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from fastvideo.utils.logging_ import main_print


def save_checkpoint(model,
optimizer,
rank,
output_dir,
step,
discriminator=False):
def save_checkpoint_optimizer(model,
optimizer,
rank,
output_dir,
step,
discriminator=False):
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
Expand Down Expand Up @@ -60,6 +60,32 @@ def save_checkpoint(model,
torch.save(optim_state, optimizer_path)


def save_checkpoint(transformer, rank, output_dir, step):
    """Save the transformer's weights and config under ``checkpoint-{step}``.

    The state dict is collected inside an ``FSDP.state_dict_type`` context
    configured for ``FULL_STATE_DICT`` with ``offload_to_cpu=True`` and
    ``rank0_only=True``. Processes with ``rank <= 0`` then write
    ``diffusion_pytorch_model.safetensors`` and ``config.json`` into
    ``<output_dir>/checkpoint-<step>``.

    Args:
        transformer: Model passed to ``FSDP.state_dict_type``; must expose
            ``state_dict()`` and a ``config`` convertible with ``dict()``.
        rank: Process rank; only ``rank <= 0`` writes files.
        output_dir: Directory in which the checkpoint folder is created.
        step: Step number used in the folder name and the log messages.
    """
    main_print(f"--> saving checkpoint at step {step}")

    full_state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(transformer, StateDictType.FULL_STATE_DICT,
                              full_state_cfg):
        cpu_state = transformer.state_dict()

    # todo move to get_state_dict
    if rank <= 0:
        checkpoint_dir = os.path.join(output_dir, f"checkpoint-{step}")
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Weights are stored in safetensors format.
        save_file(
            cpu_state,
            os.path.join(checkpoint_dir,
                         "diffusion_pytorch_model.safetensors"),
        )

        # The "dtype" entry is dropped before dumping the config (TODO).
        serializable_config = {
            key: value
            for key, value in dict(transformer.config).items()
            if key != "dtype"
        }
        config_file_path = os.path.join(checkpoint_dir, "config.json")
        with open(config_file_path, "w") as config_file:
            json.dump(serializable_config, config_file, indent=4)

    main_print(f"--> checkpoint saved at step {step}")


def save_checkpoint_generator_discriminator(
model,
optimizer,
Expand Down
2 changes: 2 additions & 0 deletions fastvideo/utils/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ def load_transformer(
)
transformer = load_hunyuan_state_dict(transformer,
dit_model_name_or_path)
if master_weight_type == torch.bfloat16:
transformer = transformer.bfloat16()
else:
raise ValueError(f"Unsupported model type: {model_type}")
return transformer
Expand Down
6 changes: 3 additions & 3 deletions fastvideo/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import numpy as np
import torch
import wandb
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from einops import rearrange
from tqdm import tqdm

import wandb
from fastvideo.distill.solver import PCMFMScheduler
from fastvideo.models.mochi_hf.pipeline_mochi import (
linear_quadratic_schedule, retrieve_timesteps)
Expand Down Expand Up @@ -294,8 +294,8 @@ def log_validation(
scheduler_type=scheduler_type,
num_frames=args.num_frames,
# Peiyuan TODO: remove hardcode
height=480,
width=848,
height=args.num_height,
width=args.num_width,
num_inference_steps=validation_sampling_step,
guidance_scale=validation_guidance_scale,
generator=generator,
Expand Down
61 changes: 54 additions & 7 deletions scripts/distill/distill_hunyuan.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,69 @@ export WANDB_BASE_URL="https://api.wandb.ai"
export WANDB_MODE=online

DATA_DIR=./data
IP=[MASTER NODE IP]

torchrun --nnodes 4 --nproc_per_node 8\
--node_rank=0 \
--rdzv_id=456 \
--rdzv_backend=c10d \
--rdzv_endpoint=$IP:29500 \
fastvideo/distill.py\
--seed 42\
--pretrained_model_name_or_path $DATA_DIR/hunyuan\
--dit_model_name_or_path $DATA_DIR/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt\
--model_type "hunyuan" \
--cache_dir "$DATA_DIR/.cache"\
--data_json_path "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/videos2caption.json"\
--validation_prompt_dir "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/validation"\
--gradient_checkpointing\
--train_batch_size=1\
--num_latent_t 32 \
--sp_size 2 \
--train_sp_batch_size 1\
--dataloader_num_workers 4\
--gradient_accumulation_steps=1\
--max_train_steps=320\
--learning_rate=1e-6\
--mixed_precision="bf16"\
--checkpointing_steps=64\
--validation_steps 64\
--validation_sampling_steps "2,4,8" \
--checkpoints_total_limit 3\
--allow_tf32\
--ema_start_step 0\
--cfg 0.0\
--log_validation\
--output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_16_HD"\
--tracker_project_name Hunyuan_Distill \
--num_height 720 \
--num_width 1280 \
--num_frames 125 \
--shift 17 \
--validation_guidance_scale "1.0" \
--num_euler_timesteps 50 \
--multi_phased_distill_schedule "4000-1" \
--not_apply_cfg_solver


# If you do not have 32 GPUs, you can make the run fit in memory by: 1. increasing sp_size; 2. reducing num_latent_t.
torchrun --nnodes 1 --nproc_per_node 8\
fastvideo/distill.py\
--seed 42\
--pretrained_model_name_or_path $DATA_DIR/hunyuan\
--dit_model_name_or_path $DATA_DIR/hunyuan/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt\
--model_type "hunyuan" \
--cache_dir "$DATA_DIR/.cache"\
--data_json_path "$DATA_DIR/Hunyuan-30K-Distill-Data/videos2caption.json"\
--validation_prompt_dir "$DATA_DIR/Hunyuan-Distill-Data/validation"\
--data_json_path "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/videos2caption.json"\
--validation_prompt_dir "$DATA_DIR/HD-Mixkit-Finetune-Hunyuan/validation"\
--gradient_checkpointing\
--train_batch_size=1\
--num_latent_t 24\
--sp_size 1\
--num_latent_t 32 \
--sp_size 2 \
--train_sp_batch_size 1\
--dataloader_num_workers 4\
--gradient_accumulation_steps=1\
--max_train_steps=2000\
--max_train_steps=320\
--learning_rate=1e-6\
--mixed_precision="bf16"\
--checkpointing_steps=64\
Expand All @@ -30,9 +75,11 @@ torchrun --nnodes 1 --nproc_per_node 8\
--ema_start_step 0\
--cfg 0.0\
--log_validation\
--output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_32"\
--output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_16_HD"\
--tracker_project_name Hunyuan_Distill \
--num_frames 93 \
--num_height 720 \
--num_width 1280 \
--num_frames 125 \
--shift 17 \
--validation_guidance_scale "1.0" \
--num_euler_timesteps 50 \
Expand Down
2 changes: 2 additions & 0 deletions scripts/finetune/finetune_hunyuan.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,7 @@ torchrun --nnodes 1 --nproc_per_node 8 \
--output_dir=data/outputs/HSH-Taylor-Finetune-Hunyuan \
--tracker_project_name HSH-Taylor-Finetune-Hunyuan \
--num_frames 93 \
--num_height 720 \
--num_width 1280 \
--validation_guidance_scale "1.0" \
--group_frame
Loading