Flux Training: AttributeError: 'T5EncoderModel' object has no attribute 'text_model' #1451

Closed
YujiaKCL opened this issue Aug 13, 2024 · 8 comments

Comments

@YujiaKCL

The Flux inference script runs fine, but the training script fails with the following error:
AttributeError: 'T5EncoderModel' object has no attribute 'text_model'.

Has anyone else encountered this issue?

@kohya-ss
Owner

requirements.txt has been updated. Could you please update your dependencies with pip install --use-pep517 --upgrade -r requirements.txt?

@YujiaKCL
Author

Thanks for the response. I have updated the dependencies, but it is still not working. The environment seems fine; the problem appears to be that the T5 model structure is not handled correctly.

For T5, I use t5xxl_fp16.safetensors from https://huggingface.co/stabilityai/stable-diffusion-3-medium/tree/main/text_encoders. Unlike CLIP-L, transformers.T5EncoderModel does not have a "text_model" attribute; it has "encoder" instead.
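
To make the structural difference concrete, here is a minimal sketch using plain-Python stand-in classes. FakeClipTextModel and FakeT5EncoderModel are hypothetical names, not the real transformers classes; they only mirror the top-level attribute names (text_model for CLIP, shared and encoder for T5), and top_level_stack_name is an illustrative helper.

```python
from types import SimpleNamespace


class FakeClipTextModel:
    """Stand-in for transformers.CLIPTextModel: layers live under `text_model`."""

    def __init__(self):
        self.text_model = SimpleNamespace(embeddings="clip-embeddings", encoder="clip-encoder")


class FakeT5EncoderModel:
    """Stand-in for transformers.T5EncoderModel: no `text_model`, only `shared` and `encoder`."""

    def __init__(self):
        self.shared = "t5-shared-embedding"
        self.encoder = SimpleNamespace(embed_tokens="t5-embed-tokens")


def top_level_stack_name(t_enc):
    """Return the name of the attribute that holds the transformer stack, or None."""
    for name in ("text_model", "encoder"):
        if hasattr(t_enc, name):
            return name
    return None


print(top_level_stack_name(FakeClipTextModel()))   # text_model
print(top_level_stack_name(FakeT5EncoderModel()))  # encoder
```

Any code that reaches for text_model unconditionally will therefore work for the CLIP-style object and fail for the T5-style one.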

@kohya-ss
Owner

Could you please share the full stack trace of the error?

@YujiaKCL
Author

Error message and command (with the CLIP and T5 modules printed):
accelerate launch --num_processes 1 --mixed_precision bf16 --num_cpu_threads_per_process 8 flux_train_network.py \
  --pretrained_model_name_or_path /pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/flux1-dev.sft \
  --clip_l /pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/clip_l.safetensors \
  --t5xxl /pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/t5xxl_fp16.safetensors \
  --ae /pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/ae.sft \
  --resolution=1024,1024 \
  --cache_latents_to_disk --cache_text_encoder_outputs_to_disk --save_model_as safetensors --sdpa \
  --persistent_data_loader_workers --max_data_loader_n_workers 8 --seed 42 \
  --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 \
  --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 \
  --network_train_unet_only --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 \
  --output_dir outputs --output_name NAME --timestep_sampling sigmoid \
  --model_prediction_type raw --guidance_scale 1.0 --loss_type l2 \
  --train_data_dir testing_datasets/ds3 \
  --output_dir outputs/test \
  --logging_dir logs/test
The following values were not passed to accelerate launch and had defaults used instead:
--num_machines was set to a value of 1
--dynamo_backend was set to a value of 'no'
To avoid this warning pass in values for each of the problematic parameters or run accelerate config.
/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/venv/lib/python3.10/site-packages/diffusers/utils/outputs.py:63: FutureWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
torch.utils._pytree._register_pytree_node(
/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/venv/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: torch.library.impl_abstract was renamed to torch.library.register_fake. Please use that instead; we will remove torch.library.impl_abstract in a future version of PyTorch.
@torch.library.impl_abstract("xformers_flash::flash_fwd")
/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/venv/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: torch.library.impl_abstract was renamed to torch.library.register_fake. Please use that instead; we will remove torch.library.impl_abstract in a future version of PyTorch.
@torch.library.impl_abstract("xformers_flash::flash_bwd")
highvram is enabled / highvramが有効です
2024-08-13 20:14:07 WARNING cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします train_util.py:3899
/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/venv/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: clean_up_tokenization_spaces was not set. It will be set to True by default. This behavior will be depracted in transformers v4.45, and will be then set to False by default. For more details check this issue: huggingface/transformers#31884
warnings.warn(
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the legacy (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set legacy=False. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in huggingface/transformers#24565
2024-08-13 20:14:10 INFO Using DreamBooth method. train_network.py:276
WARNING ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: .ipynb_checkpoints config_util.py:589
WARNING ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: captions config_util.py:589
INFO prepare images. train_util.py:1807
INFO get image size from name of cache files train_util.py:1745
100%|██████████| 100/100 [00:00<00:00, 2048.73it/s]
INFO set image size from cache files: 100/100 train_util.py:1752
INFO found directory /pfs/mt-BzjfJP/yx/projs/kohya_flux/testing_datasets/ds3/5_动漫_新海诚风格 contains 100 image files train_util.py:1754
WARNING No caption file found for 100 images. Training will continue without captions for these images. If class token exists, it will be used. / train_util.py:1785
100枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。
WARNING /pfs/mt-BzjfJP/yx/projs/kohya_flux/testing_datasets/ds3/5_image
(1).jpg train_util.py:1792
WARNING /pfs/mt-BzjfJP/yx/projs/kohya_flux/testing_datasets/ds3/5_image
(10).jpg train_util.py:1792
WARNING /pfs/mt-BzjfJP/yx/projs/kohya_flux/testing_datasets/ds3/5_image_ (100).jpg train_util.py:1792
WARNING /pfs/mt-BzjfJP/yx/projs/kohya_flux/testing_datasets/ds3/5_image_ (11).jpg train_util.py:1792
WARNING /pfs/mt-BzjfJP/yx/projs/kohya_flux/testing_datasets/ds3/5_image_ (12).jpg train_util.py:1792
WARNING /pfs/mt-BzjfJP/yx/projs/kohya_flux/testing_datasets/ds3/5_image_ (13).jpg... and 95 more train_util.py:1790
INFO 500 train images with repeating. train_util.py:1848
INFO 0 reg images. train_util.py:1851
WARNING no regularization images / 正則化画像が見つかりませんでした train_util.py:1856
INFO [Dataset 0] config_util.py:570
batch_size: 1
resolution: (1024, 1024)
enable_bucket: False
network_multiplier: 1.0

[Subset 0 of Dataset 0]
  image_dir: "/pfs/mt-BzjfJP/yx/projs/kohya_flux/testing_datasets/ds3/5_images"
  image_count: 100
  num_repeats: 5
  shuffle_caption: False
  keep_tokens: 0
  keep_tokens_separator:
  caption_separator: ,
  secondary_separator: None
  enable_wildcard: False
  caption_dropout_rate: 0.0
  caption_dropout_every_n_epoches: 0
  caption_tag_dropout_rate: 0.0
  caption_prefix: None
  caption_suffix: None
  color_aug: False
  flip_aug: False
  face_crop_aug_range: None
  random_crop: False
  token_warmup_min: 1,
  token_warmup_step: 0,
  alpha_mask: False,
  is_reg: False
  class_tokens: tok
  caption_extension: .caption

INFO [Dataset 0] config_util.py:576
INFO loading image sizes. train_util.py:876
100%|██████████| 100/100 [00:00<00:00, 2129088.32it/s]
INFO prepare dataset train_util.py:884
INFO preparing accelerator train_network.py:329
accelerator device: cuda
INFO Building CLIP flux_utils.py:48
INFO Loading state dict from /pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/clip_l.safetensors flux_utils.py:141
2024-08-13 20:14:11 INFO Loaded CLIP: flux_utils.py:144
INFO Loading state dict from /pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/t5xxl_fp16.safetensors flux_utils.py:187
INFO Loaded T5xxl: flux_utils.py:190
INFO Building Flux model dev flux_utils.py:23
INFO Loading state dict from /pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/flux1-dev.sft flux_utils.py:28
2024-08-13 20:14:12 INFO Loaded Flux: flux_utils.py:31
INFO Building AutoEncoder flux_utils.py:36
INFO Loading state dict from /pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/ae.sft flux_utils.py:40
INFO Loaded AE: flux_utils.py:43
import network module: networks.lora_flux
2024-08-13 20:14:14 INFO [Dataset 0] train_util.py:2330
INFO caching latents with caching strategy. train_util.py:984
INFO checking cache validity... train_util.py:994
100%|██████████| 100/100 [00:00<00:00, 219.58it/s]
INFO no latents to cache train_util.py:1034
2024-08-13 20:14:19 INFO create LoRA network. base dim (rank): 4, alpha: 1 lora_flux.py:358
INFO neuron dropout: p=None, rank dropout: p=None, module dropout: p=None lora_flux.py:359
INFO create LoRA for Text Encoder 1: lora_flux.py:430
INFO create LoRA for Text Encoder 2: lora_flux.py:430
INFO create LoRA for Text Encoder: 24 modules. lora_flux.py:435
INFO create LoRA for U-Net: 304 modules. lora_flux.py:439
INFO enable LoRA for U-Net: 304 modules lora_flux.py:482
FLUX: Gradient checkpointing enabled.
prepare optimizer, data loader etc.
INFO use 8-bit AdamW optimizer | {} train_util.py:4346
override steps. steps for 4 epochs is / 指定エポックまでのステップ数: 2000
enable fp8 training.
CLIPTextModel(
(text_model): CLIPTextTransformer(
(embeddings): CLIPTextEmbeddings(
(token_embedding): Embedding(49408, 768)
(position_embedding): Embedding(77, 768)
)
(encoder): CLIPEncoder(
(layers): ModuleList(
(0-11): 12 x CLIPEncoderLayer(
(self_attn): CLIPSdpaAttention(
(k_proj): Linear(in_features=768, out_features=768, bias=True)
(v_proj): Linear(in_features=768, out_features=768, bias=True)
(q_proj): Linear(in_features=768, out_features=768, bias=True)
(out_proj): Linear(in_features=768, out_features=768, bias=True)
)
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): CLIPMLP(
(activation_fn): GELUActivation()
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(fc2): Linear(in_features=3072, out_features=768, bias=True)
)
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
)
(final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
T5EncoderModel(
(shared): Embedding(32128, 4096)
(encoder): T5Stack(
(embed_tokens): Embedding(32128, 4096)
(block): ModuleList(
(0): T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(in_features=4096, out_features=4096, bias=False)
(k): Linear(in_features=4096, out_features=4096, bias=False)
(v): Linear(in_features=4096, out_features=4096, bias=False)
(o): Linear(in_features=4096, out_features=4096, bias=False)
(relative_attention_bias): Embedding(32, 64)
)
(layer_norm): T5LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=4096, out_features=10240, bias=False)
(wi_1): Linear(in_features=4096, out_features=10240, bias=False)
(wo): Linear(in_features=10240, out_features=4096, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): T5LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
(1-23): 23 x T5Block(
(layer): ModuleList(
(0): T5LayerSelfAttention(
(SelfAttention): T5Attention(
(q): Linear(in_features=4096, out_features=4096, bias=False)
(k): Linear(in_features=4096, out_features=4096, bias=False)
(v): Linear(in_features=4096, out_features=4096, bias=False)
(o): Linear(in_features=4096, out_features=4096, bias=False)
)
(layer_norm): T5LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
(1): T5LayerFF(
(DenseReluDense): T5DenseGatedActDense(
(wi_0): Linear(in_features=4096, out_features=10240, bias=False)
(wi_1): Linear(in_features=4096, out_features=10240, bias=False)
(wo): Linear(in_features=10240, out_features=4096, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
(act): NewGELUActivation()
)
(layer_norm): T5LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(final_layer_norm): T5LayerNorm()
(dropout): Dropout(p=0.1, inplace=False)
)
)
Traceback (most recent call last):
File "/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/flux_train_network.py", line 397, in <module>
trainer.train(args)
File "/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/train_network.py", line 544, in train
if hasattr(t_enc.text_model, "embeddings"):
File "/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1729, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'T5EncoderModel' object has no attribute 'text_model'
Traceback (most recent call last):
File "/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/venv/bin/accelerate", line 8, in <module>
sys.exit(main())
File "/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/venv/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
args.func(args)
File "/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1106, in launch_command
simple_launcher(args)
File "/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 704, in simple_launcher
raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts/venv/bin/python', 'flux_train_network.py', '--pretrained_model_name_or_path', '/pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/flux1-dev.sft', '--clip_l', '/pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/clip_l.safetensors', '--t5xxl', '/pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/t5xxl_fp16.safetensors', '--ae', '/pfs/mt-BzjfJP/yx/projs/kohya_flux/models/FLUX.1-dev/ae.sft', '--resolution=1024,1024', '--cache_latents_to_disk', '--cache_text_encoder_outputs_to_disk', '--save_model_as', 'safetensors', '--sdpa', '--persistent_data_loader_workers', '--max_data_loader_n_workers', '2', '--seed', '42', '--gradient_checkpointing', '--mixed_precision', 'bf16', '--save_precision', 'bf16', '--network_module', 'networks.lora_flux', '--network_dim', '4', '--optimizer_type', 'adamw8bit', '--learning_rate', '1e-4', '--network_train_unet_only', '--fp8_base', '--highvram', '--max_train_epochs', '4', '--save_every_n_epochs', '1', '--output_dir', 'outputs', '--output_name', 'NAME', '--timestep_sampling', 'sigmoid', '--model_prediction_type', 'raw', '--guidance_scale', '1.0', '--loss_type', 'l2', '--train_data_dir', '/pfs/mt-BzjfJP/yx/projs/kohya_flux/testing_datasets/ds3', '--output_dir', 'outputs/test', '--logging_dir', 'logs/test']' returned non-zero exit status 1.

Pip:
absl-py 2.1.0
accelerate 0.33.0
aiohappyeyeballs 2.3.5
aiohttp 3.10.3
aiosignal 1.3.1
altair 4.2.2
async-timeout 4.0.3
attrs 24.2.0
bitsandbytes 0.43.3
certifi 2024.7.4
charset-normalizer 3.3.2
diffusers 0.25.0
easygui 0.98.3
einops 0.7.0
entrypoints 0.4
filelock 3.15.4
frozenlist 1.4.1
fsspec 2024.6.1
ftfy 6.1.1
grpcio 1.65.4
huggingface-hub 0.24.5
idna 3.7
imagesize 1.4.1
importlib_metadata 8.2.0
Jinja2 3.1.4
jsonschema 4.23.0
jsonschema-specifications 2023.12.1
library 0.0.0 /pfs/mt-BzjfJP/yx/projs/kohya_flux/sd-scripts
lightning-utilities 0.11.6
lion-pytorch 0.0.6
Markdown 3.6
markdown-it-py 3.0.0
MarkupSafe 2.1.5
mdurl 0.1.2
mpmath 1.3.0
multidict 6.0.5
networkx 3.3
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.6.20
nvidia-nvtx-cu12 12.1.105
opencv-python 4.7.0.68
packaging 24.1
pandas 2.2.2
pillow 10.4.0
pip 21.2.3
prodigyopt 1.0
protobuf 4.25.4
psutil 6.0.0
Pygments 2.18.0
python-dateutil 2.9.0.post0
pytorch-lightning 1.9.0
pytz 2024.1
PyYAML 6.0.2
referencing 0.35.1
regex 2024.7.24
requests 2.32.3
rich 13.7.0
rpds-py 0.20.0
safetensors 0.4.2
sentencepiece 0.2.0
setuptools 57.4.0
six 1.16.0
sympy 1.13.2
tensorboard 2.17.0
tensorboard-data-server 0.7.2
tokenizers 0.19.1
toml 0.10.2
toolz 0.12.1
torch 2.4.0
torchmetrics 1.4.1
torchvision 0.19.0
tqdm 4.66.5
transformers 4.44.0
triton 3.0.0
typing_extensions 4.12.2
tzdata 2024.1
urllib3 2.2.2
voluptuous 0.13.1
wcwidth 0.2.13
Werkzeug 3.0.3
xformers 0.0.27.post2
yarl 1.9.4
zipp 3.20.0

@kohya-ss
Owner

It looks like you're using an old version of the source code. Please git pull to the latest version.

@YujiaKCL
Author

It works now. The key is to make sure cache_text_encoder_outputs=True; setting only cache_text_encoder_outputs_to_disk=True causes the error.

When cache_text_encoder_outputs=False and T5 is loaded on the GPU, the following lines raise the error:

if hasattr(t_enc.text_model, "embeddings"):
    # nn.Embedding not support FP8
    t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
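
These lines fail because Python evaluates t_enc.text_model before hasattr is ever called, so hasattr cannot return False for a model that lacks text_model. Below is a minimal sketch of a guarded check. It uses hypothetical stand-in objects rather than the real models, and has_clip_style_embeddings is an illustrative helper, not part of sd-scripts.

```python
from types import SimpleNamespace

# Hypothetical stand-ins that only mimic the attribute layout of the two encoders.
clip_like = SimpleNamespace(text_model=SimpleNamespace(embeddings=object()))
t5_like = SimpleNamespace(shared=object(), encoder=SimpleNamespace(embed_tokens=object()))


def unguarded_check(t_enc):
    # Same shape as the failing line: `t_enc.text_model` is evaluated first,
    # so a missing `text_model` raises before hasattr() can return False.
    return hasattr(t_enc.text_model, "embeddings")


def has_clip_style_embeddings(t_enc):
    # Guarded variant: look up `text_model` with a default, then test it.
    text_model = getattr(t_enc, "text_model", None)
    return text_model is not None and hasattr(text_model, "embeddings")


try:
    unguarded_check(t5_like)
    unguarded_raised = False
except AttributeError:
    unguarded_raised = True

print(unguarded_raised)                      # True
print(has_clip_style_embeddings(clip_like))  # True
print(has_clip_style_embeddings(t5_like))    # False
```

Whether sd-scripts adopted a guard like this is not stated in this thread; it is shown only to explain the failure mode.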

@kohya-ss
Owner

Thank you for reporting the issue. I will set cache_text_encoder_outputs to True when cache_text_encoder_outputs_to_disk=True.
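
A minimal sketch of that kind of normalization, assuming argparse-style boolean flags with the same names as the command-line options in this thread. The parser setup and the normalize_cache_args helper are hypothetical and are not the actual sd-scripts change.

```python
import argparse


def build_parser():
    # Only the two flags discussed in this thread; the real script has many more.
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_text_encoder_outputs", action="store_true")
    parser.add_argument("--cache_text_encoder_outputs_to_disk", action="store_true")
    return parser


def normalize_cache_args(args):
    # Caching to disk implies caching, so turn the base flag on when needed.
    if args.cache_text_encoder_outputs_to_disk:
        args.cache_text_encoder_outputs = True
    return args


args = normalize_cache_args(build_parser().parse_args(["--cache_text_encoder_outputs_to_disk"]))
print(args.cache_text_encoder_outputs)  # True
```

This mirrors the behavior already visible in the log for latents, where cache_latents_to_disk turns cache_latents on with a warning.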

@kohya-ss
Owner

I updated the code. Thanks!
