diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index fb567ebfb3..26c4cf1fdb 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -583,7 +583,7 @@ def from_pretrained(
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
-        model.load_adapter(
+        load_result = model.load_adapter(
             model_id,
             adapter_name,
             is_trainable=is_trainable,
@@ -592,6 +592,17 @@ def from_pretrained(
             **kwargs,
         )
 
+        # 1. Remove VB-LoRA vector bank, since it's a shared parameter set via the VBLoRAModel
+        # 2. Remove the prompt encoder, as it does not need to be part of the checkpoint
+        missing_keys = [
+            k for k in load_result.missing_keys if "vblora_vector_bank" not in k and "prompt_encoder" not in k
+        ]
+        if missing_keys:
+            # Let's warn here since (in contrast to load_adapter) we don't return the load result, so it could be quite
+            # difficult for users to even notice that something might have gone wrong here. As we filter out non PEFT
+            # keys from the missing keys, this gives no false positives.
+            warnings.warn(f"Found missing adapter keys while loading the checkpoint: {missing_keys}")
+
         return model
 
     def _setup_prompt_encoder(self, adapter_name: str):
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index cc54003350..7284acbb98 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -1512,3 +1512,35 @@ def test_mixed_model_load_adapter_low_cpu_mem_usage_works(self, device, inputs,
 
         assert device_set_low_cpu_mem == device_set_not_low_cpu_mem
         assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
+
+
+def test_from_pretrained_missing_keys_warning(recwarn, tmp_path):
+    # For more context, see issue 2115
+    # When loading a PEFT adapter and we're missing a PEFT-specific weight, there should be a warning.
+    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
+    config = LoraConfig()
+    model = get_peft_model(model, config)
+    state_dict = model.state_dict()
+
+    # first, sanity check that there are no warnings if no key is missing
+    model.save_pretrained(tmp_path)
+    del model
+    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
+    model = PeftModel.from_pretrained(model, tmp_path)
+    msg = "Found missing adapter keys"
+    assert not any(msg in str(w.message) for w in recwarn.list)
+
+    # remove a key from the state_dict
+    missing_key = "base_model.model.model.decoder.layers.0.self_attn.v_proj.lora_A.default.weight"
+
+    def new_state_dict():
+        return {k: v for k, v in state_dict.items() if k != missing_key}
+
+    model.state_dict = new_state_dict
+    model.save_pretrained(tmp_path)
+    del model
+
+    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
+    model = PeftModel.from_pretrained(model, tmp_path)
+    assert any(msg in str(w.message) for w in recwarn.list)
+    assert any(missing_key in str(w.message) for w in recwarn.list)
diff --git a/tests/testing_common.py b/tests/testing_common.py
index 860948bcfb..3eec02510f 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -18,6 +18,7 @@
 import re
 import shutil
 import tempfile
+import warnings
 from collections import OrderedDict
 from dataclasses import replace
 
@@ -378,7 +379,10 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs, safe_serial
             model.save_pretrained(tmp_dirname, safe_serialization=False)
 
             model_from_pretrained = self.transformers_class.from_pretrained(model_id)
-            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
+            with warnings.catch_warnings(record=True) as recs:
+                model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
+                # ensure that there is no warning
+                assert not any("Found missing adapter keys" in str(rec.message) for rec in recs)
 
             # check if the state dicts are equal
             if issubclass(config_cls, PromptEncoderConfig):
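Not part of the patch: a minimal sketch of how the new warning surfaces for a user, assuming a locally saved adapter directory "./my-adapter" whose checkpoint happens to be missing one of its LoRA weights (the directory name and the missing-key scenario are hypothetical here).

# Sketch: observing the "Found missing adapter keys" warning added by this patch.
import warnings

from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # If "./my-adapter" lacks some adapter weights, from_pretrained now warns
    # instead of silently leaving those weights at their initial values.
    model = PeftModel.from_pretrained(base, "./my-adapter")

for w in caught:
    if "Found missing adapter keys" in str(w.message):
        print(w.message)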