
DeepSpeed and DDP Configs #10

Merged
merged 6 commits into from
Oct 8, 2024
Conversation

a-r-r-o-w
Owner

No description provided.

@a-r-r-o-w
Owner Author

a-r-r-o-w commented Oct 8, 2024

DeepSpeed errors out with: (cc @sayakpaul)

[rank1]: Traceback (most recent call last):
[rank1]:   File "/raid/aryan/cogvideox-distillation/training/cogvideox_text_to_video_lora.py", line 911, in <module>
[rank1]:     main(args)
[rank1]:   File "/raid/aryan/cogvideox-distillation/training/cogvideox_text_to_video_lora.py", line 684, in main
[rank1]:     model_output = transformer(
[rank1]:   File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1899, in forward
[rank1]:     loss = self.module(*inputs, **kwargs)
[rank1]:   File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/aryan/work/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d.py", line 443, in forward
[rank1]:     emb = self.time_embedding(t_emb, timestep_cond)
[rank1]:   File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/aryan/work/diffusers/src/diffusers/models/embeddings.py", line 805, in forward
[rank1]:     sample = self.linear_1(sample)
[rank1]:   File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
[rank1]:     return F.linear(input, self.weight, self.bias)
[rank1]: RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
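The failure can be reproduced in isolation: `F.linear` refuses mixed dtypes, which is what happens when the fp32 timestep embedding reaches weights that DeepSpeed has cast to bf16. A minimal sketch (no DeepSpeed or GPUs needed; tensor names are illustrative):

```python
import torch
import torch.nn.functional as F

weight = torch.randn(4, 8, dtype=torch.bfloat16)  # weight cast to bf16 (as DeepSpeed does)
sample = torch.randn(2, 8, dtype=torch.float32)   # input left in fp32

try:
    F.linear(sample, weight)
except RuntimeError as e:
    print(e)  # mat1 and mat2 must have the same dtype

# Casting the input to the weight dtype resolves it:
out = F.linear(sample.to(weight.dtype), weight)
assert out.dtype == torch.bfloat16
```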

DDP + uncompiled: works

DDP + compiled: does not work. I don't think this setting has ever worked for me, or that the two are compatible with each other (some quick googling suggests as much)

@sayakpaul
Collaborator

DeepSpeed errors out with:

The details section unfolds to something blank.

@a-r-r-o-w
Owner Author

Oh sorry, really weird! Updated

@sayakpaul
Collaborator

Thanks. How can I reproduce the error?

@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
Collaborator


We should not do multi-GPU for now I guess to keep things as simple as possible. Or is there anything I am missing?

Owner Author


We already have configs for single GPU compiled/uncompiled. This is for testing DDP on 2 GPUs (which works uncompiled, but not compiled - I'm investigating what's wrong)

@a-r-r-o-w
Owner Author

Thanks. How can I reproduce the error?

I think you just have to change the config in the train_text_to_video_lora.sh file to use the DeepSpeed one.

This is what I'm using for example (from the root folder of the repo):

export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0

GPU_IDS="2,3"

DATA_ROOT="training/dump"

CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"

cmd="accelerate launch --config_file accelerate_configs/deepspeed.yaml --gpu_ids $GPU_IDS training/cogvideox_text_to_video_lora.py \
  --pretrained_model_name_or_path THUDM/CogVideoX-5b \
  --data_root $DATA_ROOT \
  --caption_column $CAPTION_COLUMN \
  --video_column $VIDEO_COLUMN \
  --id_token BW_STYLE \
  --height_buckets 480 \
  --width_buckets 720 \
  --frame_buckets 49 \
  --load_tensors \
  --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \
  --validation_prompt_separator ::: \
  --num_validation_videos 1 \
  --validation_epochs 1 \
  --seed 42 \
  --rank 64 \
  --lora_alpha 64 \
  --mixed_precision bf16 \
  --output_dir /raid/aryan/cogvideox-lora \
  --max_num_frames 49 \
  --train_batch_size 1 \
  --max_train_steps 3000 \
  --checkpointing_steps 1000 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --learning_rate 0.0001 \
  --lr_scheduler constant \
  --lr_warmup_steps 200 \
  --lr_num_cycles 1 \
  --enable_slicing \
  --enable_tiling \
  --optimizer adamw \
  --beta1 0.9 \
  --beta2 0.95 \
  --beta3 0.99 \
  --weight_decay 0.001 \
  --max_grad_norm 1.0 \
  --allow_tf32 \
  --report_to wandb \
  --nccl_timeout 1800"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"

@sayakpaul
Collaborator

@a-r-r-o-w DeepSpeed seems to be working.

Patch:
diff --git a/training/cogvideox_text_to_video_lora.py b/training/cogvideox_text_to_video_lora.py
index fa6b6e0..c2a29d6 100644
--- a/training/cogvideox_text_to_video_lora.py
+++ b/training/cogvideox_text_to_video_lora.py
@@ -315,7 +315,7 @@ def main(args):
             "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
             and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
         ):
-            weight_dtype = torch.float16
+            weight_dtype = torch.bfloat16
     else:
         if accelerator.mixed_precision == "fp16":
             weight_dtype = torch.float16
@@ -631,7 +631,7 @@ def main(args):
 
                 videos = latent_dist.sample() * VAE_SCALING_FACTOR
                 videos = videos.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-                videos = videos.to(memory_format=torch.contiguous_format).float()
+                videos = videos.to(memory_format=torch.contiguous_format).to(weight_dtype)
                 model_input = videos
 
                 # Encode prompts
@@ -646,7 +646,7 @@ def main(args):
                         requires_grad=False,
                     )
                 else:
-                    prompt_embeds = prompts
+                    prompt_embeds = prompts.to(weight_dtype)
 
                 # Sample noise that will be added to the latents
                 noise = torch.randn_like(model_input)
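The crux of the patch is the dtype-selection branch: when the DeepSpeed config enables bf16, `weight_dtype` must be `torch.bfloat16` (the original code set `torch.float16`), and the latents and prompt embeds must then be cast to it. A hypothetical helper sketching the corrected selection logic (function and argument names are illustrative, not from the repo):

```python
import torch

def resolve_weight_dtype(deepspeed_config: dict, mixed_precision: str = "no") -> torch.dtype:
    """Pick the training weight dtype from a DeepSpeed config dict,
    falling back to the accelerate mixed-precision setting."""
    if deepspeed_config:
        if deepspeed_config.get("fp16", {}).get("enabled"):
            return torch.float16
        if deepspeed_config.get("bf16", {}).get("enabled"):
            # The fix in this PR: a bf16 DeepSpeed config must map to
            # torch.bfloat16, not torch.float16.
            return torch.bfloat16
    elif mixed_precision == "fp16":
        return torch.float16
    elif mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32

assert resolve_weight_dtype({"bf16": {"enabled": True}}) is torch.bfloat16
```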
accelerate config
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Training command
# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0

GPU_IDS="2"

DATA_ROOT="video-dataset-disney"

CAPTION_COLUMN="prompt.txt"
VIDEO_COLUMN="videos.txt"

cmd="accelerate launch --config_file accelerate_configs/deepspeed.yaml --gpu_ids $GPU_IDS training/cogvideox_text_to_video_lora.py \
  --pretrained_model_name_or_path THUDM/CogVideoX-5b \
  --data_root $DATA_ROOT \
  --caption_column $CAPTION_COLUMN \
  --video_column $VIDEO_COLUMN \
  --id_token BW_STYLE \
  --height_buckets 480 \
  --width_buckets 720 \
  --frame_buckets 49 \
  --load_tensors \
  --validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions\" \
  --validation_prompt_separator ::: \
  --num_validation_videos 1 \
  --validation_epochs 1 \
  --seed 42 \
  --rank 64 \
  --lora_alpha 64 \
  --mixed_precision bf16 \
  --output_dir lora \
  --max_num_frames 49 \
  --train_batch_size 1 \
  --max_train_steps 3000 \
  --checkpointing_steps 1000 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --learning_rate 0.0001 \
  --lr_scheduler constant \
  --lr_warmup_steps 200 \
  --lr_num_cycles 1 \
  --enable_slicing \
  --enable_tiling \
  --optimizer adamw \
  --beta1 0.9 \
  --beta2 0.95 \
  --beta3 0.99 \
  --weight_decay 0.001 \
  --max_grad_norm 1.0 \
  --allow_tf32 \
  --report_to wandb \
  --nccl_timeout 1800"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"

@a-r-r-o-w a-r-r-o-w merged commit be9d99a into main Oct 8, 2024
@a-r-r-o-w a-r-r-o-w deleted the deepspeed-and-ddp-configs branch October 8, 2024 21:24