ENH Allow rank/alpha keys to be "fully qualified" #2382

Draft · wants to merge 3 commits into main
19 changes: 13 additions & 6 deletions src/peft/tuners/lora/config.py
@@ -262,10 +262,13 @@ class LoraConfig(PeftConfig):
`nn.ModuleList` of the model, which is often called `'layers'` or `'h'`.
rank_pattern (`dict`):
The mapping from layer names or regexp expression to ranks which are different from the default rank
specified by `r`.
specified by `r`. Note that the keys in the dict are interpreted as patterns, so `"bar"` will also
match `"foo.bar"`. To prevent this, add a special prefix before the key: `f"FULL-NAME-{key}"`.
alpha_pattern (`dict`):
The mapping from layer names or regexp expression to alphas which are different from the default alpha
specified by `lora_alpha`.
specified by `lora_alpha`. Note that the keys in the dict are interpreted as patterns, so `"bar"` will
also match `"foo.bar"`. To prevent this, add a special prefix before the key: `f"FULL-NAME-{key}"`.
megatron_config (`Optional[dict]`):
The TransformerConfig arguments for Megatron. It is used to create LoRA's parallel linear layer. You can
get it like this, `core_transformer_config_from_args(get_args())`, these two functions being from Megatron.
@@ -399,17 +402,21 @@ class LoraConfig(PeftConfig):
default_factory=dict,
metadata={
"help": (
"The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. "
"For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}"
"The mapping from layer names or regexp expression to ranks which are different from the default rank "
"specified by `r`. For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}. Note that the keys "
"in the keys in the dict are interpreted as pattern, so `'bar'` will also match `'foo.bar'`. To "
"prevent this, add a special prefix before the key: `f'FULL-NAME-{key}'`."
)
},
)
alpha_pattern: Optional[dict] = field(
default_factory=dict,
metadata={
"help": (
"The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. "
"For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}"
"The mapping from layer names or regexp expression to alphas which are different from the default alpha "
"specified by `lora_alpha`. For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`} Note that the keys "
"in the keys in the dict are interpreted as pattern, so `'bar'` will also match `'foo.bar'`. To "
"prevent this, add a special prefix before the key: `f'FULL-NAME-{key}'`."
)
},
)
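A minimal usage sketch of the behavior described above. The module names `lin0`/`inner.lin0` are illustrative and mirror the nested test model added below; the prefix string "FULL-NAME-" is the one introduced by this PR:

from peft import LoraConfig

# Without the prefix, the key "lin0" is a pattern and would also match the nested "inner.lin0".
# With the prefix, only the module whose full name is exactly "lin0" receives the custom rank.
config = LoraConfig(
    target_modules=["lin0", "lin1"],
    r=8,
    rank_pattern={"FULL-NAME-lin0": 16},
)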
8 changes: 8 additions & 0 deletions src/peft/utils/constants.py
@@ -299,3 +299,11 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
# otherwise there is no point in optimizing and there is a small chance of bugs in the optimization algorithm, so no
# point in taking unnecessary risks. See #2045 for more context.
MIN_TARGET_MODULES_FOR_OPTIMIZATION = 20

# Use this prefix for rank_pattern or alpha_pattern keys that are meant to be fully qualified names rather than
# patterns. Without it, e.g.:
# `rank_pattern = {"foo": 16}`
# would match model.foo but also model.inner_model.foo. This can be avoided by setting the key as:
# `rank_pattern = {f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}foo": 16}`
# This only applies to rank_pattern and alpha_pattern; it is not intended for target_modules.
FULLY_QUALIFIED_PATTERN_KEY_PREFIX = "FULL-NAME-"
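As a usage sketch, the prefix can also be taken from this constant instead of hard-coding the string (the tests added in this PR import it from `peft.utils.constants`); the key `inner.lin0` is illustrative:

from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX

# "inner.lin0" is then matched only as the exact (fully qualified) module name, never as a suffix pattern.
rank_pattern = {f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}inner.lin0": 12}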
11 changes: 11 additions & 0 deletions src/peft/utils/other.py
@@ -34,6 +34,7 @@
from .constants import (
CONFIG_NAME,
EMBEDDING_LAYER_NAMES,
FULLY_QUALIFIED_PATTERN_KEY_PREFIX,
INCLUDE_LINEAR_LAYERS_SHORTHAND,
SAFETENSORS_WEIGHTS_NAME,
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
@@ -1054,4 +1055,14 @@ def check_file_exists_on_hf_hub(repo_id: str, filename: str, **kwargs) -> Option

def get_pattern_key(pattern_keys, key_to_match):
"""Match a substring of key_to_match in pattern keys"""
# handling of a special case: Users can prefix the key in rank_pattern or alpha_pattern with this special prefix to
# indicate that this key is supposed to be the full name, not a pattern. That way, the key "foo" can be matched
# without inadvertently matching bar.foo as well.
for key in pattern_keys:
if (
key.startswith(FULLY_QUALIFIED_PATTERN_KEY_PREFIX)
and key[len(FULLY_QUALIFIED_PATTERN_KEY_PREFIX) :] == key_to_match
):
return key

return next(filter(lambda key: re.match(rf".*\.{key}$", key_to_match), pattern_keys), key_to_match)
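A quick sketch of how this helper resolves keys, using illustrative module names; the expected results follow directly from the code above (importing from `peft.utils.other`, where the function is defined):

from peft.utils.other import get_pattern_key

pattern_keys = ["FULL-NAME-lin0", "lin1"]

# A fully qualified key only matches the exact module name ...
assert get_pattern_key(pattern_keys, "lin0") == "FULL-NAME-lin0"
# ... so a nested module with the same local name falls through (key_to_match is returned unchanged).
assert get_pattern_key(pattern_keys, "inner.lin0") == "inner.lin0"
# An unprefixed key still behaves as a suffix pattern and matches the nested module.
assert get_pattern_key(pattern_keys, "inner.lin1") == "lin1"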
Collaborator comment on lines +1058 to 1068:

I'm not sure if the introduction of a special prefix is necessary. We're already utilizing regex here, why not use the caret operator to force a full match? The current pattern doesn't allow this but I think it is easily changed.

Proof of concept:

import re

def match(pattern, name):
    return re.match(rf".*(^|\.){pattern}$", name)

assert match("foo", "model.foo")
assert not match("foo", "model.bar")
assert not match("foo", "bofoo")
assert not match("foo", "model.bofoo")

assert match("^foo", "foo")
assert not match("^foo", "model.foo")

113 changes: 112 additions & 1 deletion tests/test_custom_models.py
@@ -53,6 +53,7 @@
)
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import AuxiliaryTrainingWrapper, infer_device
from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX

from .testing_common import PeftCommonTester
from .testing_utils import get_state_dict, require_non_cpu
@@ -1952,7 +1953,7 @@ def test_mha_gradients_set_correctly(self, with_forward_call):
assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is True


class TestMultiRankAdapter(unittest.TestCase):
class TestMultiRankAdapter:
"""Tests related to multirank LoRA adapters"""

def test_multirank(self):
@@ -2019,6 +2020,116 @@ def test_multirank_2(self):
f"Rank {rank_current} is not equal to expected {rank_expected}"
)

#################
# NESTED MODELS #
#################

# see https://github.com/huggingface/diffusers/pull/10808 for discussion

# We have an issue if one key of the rank_pattern is the suffix of both a key of a module that needs the specific rank
# and a key of a module that does NOT need the specific rank. In general, most real-world models are nested, but
# not all have sub-modules that share the same name. For some diffusion models, however, this is a possible issue,
# so we add the possibility to indicate that a key should be considered "fully qualified" as opposed to a pattern to
# be matched.

def get_nested_model(self):
class Inner(nn.Module):
def __init__(self):
super().__init__()
self.lin0 = nn.Linear(3, 4)
self.lin1 = nn.Linear(4, 3)

class Outer(nn.Module):
def __init__(self):
super().__init__()
self.lin0 = nn.Linear(5, 6) # same name as for Inner
self.lin1 = nn.Linear(6, 5) # same name as for Inner
self.inner = Inner()

return Outer()

def test_nested_adapter_rank_pattern_applied_inner_and_outer(self):
# first sanity check that the r from rank_pattern can be applied to both "lin0" modules
model = self.get_nested_model()
config = LoraConfig(target_modules=["lin0", "lin1"], r=8, rank_pattern={"lin0": 16})
model = get_peft_model(model, config)

assert model.base_model.model.lin0.lora_A["default"].weight.shape[0] == 16
assert model.base_model.model.lin0.lora_B["default"].weight.shape[1] == 16
assert model.base_model.model.lin1.lora_A["default"].weight.shape[0] == 8
assert model.base_model.model.lin1.lora_B["default"].weight.shape[1] == 8
assert model.base_model.model.inner.lin0.lora_A["default"].weight.shape[0] == 16
assert model.base_model.model.inner.lin0.lora_B["default"].weight.shape[1] == 16
assert model.base_model.model.inner.lin1.lora_A["default"].weight.shape[0] == 8
assert model.base_model.model.inner.lin1.lora_B["default"].weight.shape[1] == 8

def test_nested_adapter_rank_pattern_applied_only_inner(self):
# applying the special rank to the inner module is possible by prefixing the name of the inner module
model = self.get_nested_model()
config = LoraConfig(target_modules=["lin0", "lin1"], r=8, rank_pattern={"inner.lin0": 16})
model = get_peft_model(model, config)
assert model.base_model.model.lin0.lora_A["default"].weight.shape[0] == 8
assert model.base_model.model.lin0.lora_B["default"].weight.shape[1] == 8
assert model.base_model.model.lin1.lora_A["default"].weight.shape[0] == 8
assert model.base_model.model.lin1.lora_B["default"].weight.shape[1] == 8
assert model.base_model.model.inner.lin0.lora_A["default"].weight.shape[0] == 16
assert model.base_model.model.inner.lin0.lora_B["default"].weight.shape[1] == 16
assert model.base_model.model.inner.lin1.lora_A["default"].weight.shape[0] == 8
assert model.base_model.model.inner.lin1.lora_B["default"].weight.shape[1] == 8

def test_nested_adapter_rank_pattern_applied_only_outer(self):
# This used to be impossible before, as the inner key would also match the key of the outer module. With the
# addition of the prefix, this is now possible.
model = self.get_nested_model()
config = LoraConfig(
target_modules=["lin0", "lin1"], r=8, rank_pattern={FULLY_QUALIFIED_PATTERN_KEY_PREFIX + "lin0": 16}
)
model = get_peft_model(model, config)
assert model.base_model.model.lin0.lora_A["default"].weight.shape[0] == 16
assert model.base_model.model.lin0.lora_B["default"].weight.shape[1] == 16
assert model.base_model.model.lin1.lora_A["default"].weight.shape[0] == 8
assert model.base_model.model.lin1.lora_B["default"].weight.shape[1] == 8
assert model.base_model.model.inner.lin0.lora_A["default"].weight.shape[0] == 8
assert model.base_model.model.inner.lin0.lora_B["default"].weight.shape[1] == 8
assert model.base_model.model.inner.lin1.lora_A["default"].weight.shape[0] == 8
assert model.base_model.model.inner.lin1.lora_B["default"].weight.shape[1] == 8

def test_nested_adapter_rank_pattern_applied_to_all(self):
# This used to be impossible before, as the inner key would also match the key of the outer module. With the
# addition of the prefix, this is now possible.
model = self.get_nested_model()
rank_pattern = {
FULLY_QUALIFIED_PATTERN_KEY_PREFIX + "lin0": 10,
FULLY_QUALIFIED_PATTERN_KEY_PREFIX + "lin1": 11,
FULLY_QUALIFIED_PATTERN_KEY_PREFIX + "inner.lin0": 12,
FULLY_QUALIFIED_PATTERN_KEY_PREFIX + "inner.lin1": 13,
}
config = LoraConfig(target_modules=["lin0", "lin1"], r=8, rank_pattern=rank_pattern)
model = get_peft_model(model, config)
assert model.base_model.model.lin0.lora_A["default"].weight.shape[0] == 10
assert model.base_model.model.lin0.lora_B["default"].weight.shape[1] == 10
assert model.base_model.model.lin1.lora_A["default"].weight.shape[0] == 11
assert model.base_model.model.lin1.lora_B["default"].weight.shape[1] == 11
assert model.base_model.model.inner.lin0.lora_A["default"].weight.shape[0] == 12
assert model.base_model.model.inner.lin0.lora_B["default"].weight.shape[1] == 12
assert model.base_model.model.inner.lin1.lora_A["default"].weight.shape[0] == 13
assert model.base_model.model.inner.lin1.lora_B["default"].weight.shape[1] == 13

def test_nested_adapter_alpha_pattern_applied_only_outer(self):
# The same test as above but using alpha_pattern. As the two patterns use the same logic, they should work
# equally well, but let's test explicitly to be sure.
model = self.get_nested_model()
config = LoraConfig(
target_modules=["lin0", "lin1"],
lora_alpha=8,
alpha_pattern={FULLY_QUALIFIED_PATTERN_KEY_PREFIX + "lin0": 16},
)
model = get_peft_model(model, config)
assert model.base_model.model.lin0.lora_alpha["default"] == 16
assert model.base_model.model.lin1.lora_alpha["default"] == 8
assert model.base_model.model.inner.lin0.lora_alpha["default"] == 8
assert model.base_model.model.inner.lin1.lora_alpha["default"] == 8


class TestRepr(unittest.TestCase):
"""Tests related to the repr of adapted models"""