
[Tracker] [bnb] Supporting device_map containing GPU and CPU devices #19090

Closed
younesbelkada opened this issue Sep 17, 2022 · 20 comments

@younesbelkada
Contributor

Feature request

We should be able to provide a custom device_map when using 8-bit models with bitsandbytes. This would give users more control over which modules they want to quantize.

Linked issue: bitsandbytes-foundation/bitsandbytes#40

Motivation

Users should be able to pass their own custom device_map and choose which modules should be quantized and which should not.
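
For illustration, a rough sketch of the intended usage (hypothetical for now, since mixed CPU/GPU maps are not yet supported with 8-bit loading), assuming a GPT-Neo checkpoint:

from transformers import AutoModelForCausalLM

# Hypothetical target usage: some modules on GPU 0, some offloaded to the CPU,
# while still requesting 8-bit loading for the GPU modules.
device_map = {
    "transformer.wte": 0,
    "transformer.wpe": 0,
    "transformer.h": 0,
    "transformer.ln_f": "cpu",
    "lm_head": 0,
}

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125M",
    device_map=device_map,
    load_in_8bit=True,
)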

Your contribution

Try coding this enhancement!

@z80maniac

z80maniac commented Sep 18, 2022

UPDATE (for future readers): the title was changed.


I think that the title of this issue is a little bit misleading. Technically, a custom device_map is already supported for bitsandbytes, as long as all the layers are on GPU.

For example, in the linked issue, this device_map works correctly:

    device_map = {
        "transformer.wte": 0,
        "transformer.wpe": 0,
        "transformer.ln_f": 0,
        "lm_head": 0,
        "transformer.h.0": 0,
        "transformer.h.1": 0,
        "transformer.h.2": 0,
        "transformer.h.3": 0,
        "transformer.h.4": 0,
        "transformer.h.5": 0,
        "transformer.h.6": 0,
        "transformer.h.7": 0,
        "transformer.h.8": 0,
        "transformer.h.9": 0,
        "transformer.h.10": 0,
        "transformer.h.11": 0
    }

And I believe there would be no problem using 1 instead of 0 for any transformer.* layer if you have more than one GPU (though I may be mistaken; I didn't find any specific info in the docs about using bitsandbytes with multiple GPUs). I suppose replacing all 0s with 1s would also work. So I think users can already customize the device map, as long as it doesn't put anything on the CPU. A hypothetical two-GPU split is sketched below.
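
For illustration only, a two-GPU variant of the map above might look like this (untested here, and assumes two visible GPUs):

# Hypothetical two-GPU map: first half of the blocks on GPU 0,
# second half plus ln_f and lm_head on GPU 1.
device_map = {
    "transformer.wte": 0,
    "transformer.wpe": 0,
    **{f"transformer.h.{i}": 0 for i in range(6)},
    **{f"transformer.h.{i}": 1 for i in range(6, 12)},
    "transformer.ln_f": 1,
    "lm_head": 1,
}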

The original issue was not about a custom map. It was about supporting the load_in_8bit flag for models that are split between CPU and GPU.

@younesbelkada younesbelkada changed the title [Tracker] [bnb] Supporting custom device_map [Tracker] [bnb] Supporting device_map containing GPU and CPU devices Sep 18, 2022
@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@z80maniac

If you think this still needs to be addressed please comment on this thread.

unstale

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@z80maniac

If you think this still needs to be addressed please comment on this thread.

unstale

I guess this will be my monthly routine...

@younesbelkada
Contributor Author

Hi,
The PR #20281 will not be merged until a fix is found on the bitsandbytes side.
Could you please check out that PR if you want to use this feature for now? Thanks.

@z80maniac

I've just tested that PR and it works. Thank you!

I tested it with a 13B model on a GTX 3060. Without load_in_8bit, only 10 layers fit on the GPU. With that patch and load_in_8bit=True, 19 layers fit on the GPU, which gives a 30% inference speedup in my case.

For some reason when I test it on my initial example, it gives this warning:

/home/user/test/bnb-test/transformers/src/transformers/generation/utils.py:1470: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.
  warnings.warn(

However, I was not able to reproduce it in my other more complex program.

In the PR's discussion it was said:

this will result in weights offloaded on the CPU to not be converted in int8 at all

I expected this much, but I think it's still better than nothing.

Though, are there any gotchas due to the fact that CPU layers are not converted to 8-bit?

Also, not sure how to proceed next. You said:

we should probably wait until bitsandbytes supports weights offloading in 8-bit to add this feature

So I suppose this issue should remain open? I will then add more info to my initial issue at the bitsandbytes repo.

@younesbelkada
Contributor Author

Thank you very much for your feedback, and I am happy it worked for your use case!

For some reason when I test it on my initial example, it gives this warning:

This is because your input_ids are on the CPU when you run inference! Make sure to move input_ids to the device of the first layers (so here, I guess, your GPU) before running generate.
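
For reference, when calling .generate() directly (outside of pipeline) that would look roughly like this, assuming model and tokenizer are already loaded and the first layers sit on GPU 0:

# Minimal sketch: move the prompt to the GPU that holds the first layers
# before generating (device 0 here is an assumption).
inputs = tokenizer("It was", return_tensors="pt")
input_ids = inputs.input_ids.to(0)
output_ids = model.generate(input_ids, max_length=32)
print(tokenizer.decode(output_ids[0]))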

Though, are there some gotchas in the fact that CPU layers are not converted to 8bit?

I did not quite get your question here, but CPU layers are indeed kept in their native dtype, which can be quite confusing. For example, you could provide a device_map that contains only CPU layers and still load your model with load_in_8bit - users would think they are loading their model in 8-bit on their CPU when that is actually not the case.
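
If you want to double-check what actually got converted, here is a quick sanity-check sketch (assuming model is the already-loaded 8-bit model):

import torch
import bitsandbytes as bnb

# List which linear layers were converted to int8 (bnb.nn.Linear8bitLt)
# and which stayed as regular nn.Linear in their native dtype (e.g. on the CPU).
for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear8bitLt):
        print(name, "-> int8")
    elif isinstance(module, torch.nn.Linear):
        print(name, "->", module.weight.dtype, "on", module.weight.device)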

So I suppose this issue should remain open? I will then add more info to my initial issue at the bitsandbytes repo.

Yes, it can remain open. But feel free to also jump into PR #20281 to give your opinion and stress that you find this feature useful. You can add more information on the bitsandbytes repo as well!

@z80maniac

This is because your input_ids are on the CPU when you run inference! Make sure to move input_ids to the device of the first layers (so here, I guess, your GPU) before running generate.

I use the following code:

from transformers import pipeline

pipe = pipeline(
    model="EleutherAI/gpt-neo-125M",
    max_length=32,
    model_kwargs={
        "device_map": device_map,
        "load_in_8bit": load_in_8bit
    }
)

print("\n", pipe("It was")[0]["generated_text"])

Not sure where I am supposed to set input_ids here.

I did not quite get your question here

I mean, purely from a technical standpoint, are there some downsides to mixing 8bit and 16/32bit layers?

@younesbelkada
Contributor Author

Not sure where I am supposed to set input_ids here.

Thanks for sharing the code! It's clearer for me now. Can you try adding device=0 as follows:

pipe = pipeline(
    model="EleutherAI/gpt-neo-125M",
    max_length=32,
    device=0,
    model_kwargs={
        "device_map": device_map,
        "load_in_8bit": load_in_8bit
    }
)

I mean, purely from a technical standpoint, are there some downsides to mixing 8bit and 16/32bit layers?

Indeed, from a technical standpoint I don't see any downside.

@z80maniac

When I add device=0 I get this:

Traceback (most recent call last):
  File "/home/user/test/bnb-test/main.py", line 28, in <module>
    pipe = pipeline(
  File "/home/user/test/bnb-test/transformers/src/transformers/pipelines/__init__.py", line 870, in pipeline
    return pipeline_class(model=model, framework=framework, task=task, **kwargs)
  File "/home/user/test/bnb-test/transformers/src/transformers/pipelines/text_generation.py", line 64, in __init__
    super().__init__(*args, **kwargs)
  File "/home/user/test/bnb-test/transformers/src/transformers/pipelines/base.py", line 778, in __init__
    self.model = self.model.to(self.device)
  File "/home/user/test/bnb-test/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 987, in to
    return self._apply(convert)
  File "/home/user/test/bnb-test/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 639, in _apply
    module._apply(fn)
  File "/home/user/test/bnb-test/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 639, in _apply
    module._apply(fn)
  File "/home/user/test/bnb-test/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 639, in _apply
    module._apply(fn)
  [Previous line repeated 1 more time]
  File "/home/user/test/bnb-test/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 662, in _apply
    param_applied = fn(param)
  File "/home/user/test/bnb-test/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 985, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
NotImplementedError: Cannot copy out of meta tensor; no data!

The full code for clarity:

from transformers import pipeline

auto_map = False
load_in_8bit = True

if auto_map:
    device_map = "auto"
else:
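    # Manual map: embeddings, lm_head and transformer block 0 on GPU 0; ln_f and blocks 1-11 on the CPU.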
    device_map = {
        "transformer.wte": 0,
        "transformer.wpe": 0,
        "transformer.ln_f": "cpu",
        "lm_head": 0,
        "transformer.h.0": 0,
        "transformer.h.1": "cpu",
        "transformer.h.2": "cpu",
        "transformer.h.3": "cpu",
        "transformer.h.4": "cpu",
        "transformer.h.5": "cpu",
        "transformer.h.6": "cpu",
        "transformer.h.7": "cpu",
        "transformer.h.8": "cpu",
        "transformer.h.9": "cpu",
        "transformer.h.10": "cpu",
        "transformer.h.11": "cpu"
    }

pipe = pipeline(
    model="EleutherAI/gpt-neo-125M",
    device=0,
    max_length=32,
    model_kwargs={
        "device_map": device_map,
        "load_in_8bit": load_in_8bit
    }
)

print("\n", pipe("It was")[0]["generated_text"])

The error occurs even when load_in_8bit = False.

Also, in any case, the original warning is pretty confusing. It says You are calling .generate() with the input_ids, but I never call .generate() myself.

@younesbelkada
Contributor Author

younesbelkada commented Nov 18, 2022

Thanks for sharing! I think it is fine; for now I would say you can leave the pipeline without device=0. I expect only a small speedup from it anyway, since accelerate copies the input_ids created on the CPU to the model's device at the beginning, and copies the result back to the CPU. Let me get back to you on this to see if I can find a solution.

the reason it says generate() is because pipeline calls .generate() under the hood here

@z80maniac

the reason it says generate() is because pipeline calls .generate() under the hood here

I know, but to an end user it still won't be immediately clear what the problem is just from reading that error message. It also says how to fix it:

Please make sure that you have put input_ids to the correct device
by calling for example input_ids = input_ids.to('cuda') before running .generate()

But that advice is not applicable at all in this situation, which adds even more confusion. Maybe the pipeline call should produce a different error message?

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@z80maniac

unstale

Also, I added some comments in the PR discussion:
#20281 (comment)
#20281 (comment)

@github-actions

github-actions bot commented Jan 6, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@z80maniac

unstale

Technically, I personally don't need this fix anymore, since in my project I applied the hack described in the PR.
Though it would be nice to have it properly integrated into transformers.

@huggingface deleted a comment from github-actions bot Feb 2, 2023
@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@younesbelkada
Contributor Author

This should be solved by the introduction of BitsAndBytesConfig in #21579

@z80maniac

Yes, indeed it works. Thank you, @younesbelkada!

For completeness' sake, here's the final working version:

import torch
from transformers import BitsAndBytesConfig, pipeline

device_map = {
    "transformer.wte": 0,
    "transformer.wpe": 0,
    "transformer.ln_f": "cpu",
    "lm_head": 0,
    "transformer.h.0": 0,
    "transformer.h.1": "cpu",
    "transformer.h.2": "cpu",
    "transformer.h.3": "cpu",
    "transformer.h.4": "cpu",
    "transformer.h.5": "cpu",
    "transformer.h.6": "cpu",
    "transformer.h.7": "cpu",
    "transformer.h.8": "cpu",
    "transformer.h.9": "cpu",
    "transformer.h.10": "cpu",
    "transformer.h.11": "cpu"
}


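# Quantization settings: load in 8-bit, keep CPU-offloaded modules in fp32,
# and skip quantizing lm_head.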
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
    llm_int8_skip_modules=["lm_head"]
)

pipe = pipeline(
    model="EleutherAI/gpt-neo-125M",
    max_length=32,
    torch_dtype=torch.float16,
    model_kwargs={
        "device_map": device_map,
        "quantization_config": quantization_config
    }
)

print("\n", pipe("It was")[0]["generated_text"])
