Slow Tokenizer adds whitespace after special token #25073

Closed
2 of 4 tasks
g588928812 opened this issue Jul 25, 2023 · 13 comments

Comments

@g588928812

g588928812 commented Jul 25, 2023

System Info

Python 3.10.6
Transformers 4.31.0
<class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer
import transformers

tokenizer = AutoTokenizer.from_pretrained(
	"../models/llama-2-7b",
	use_fast=False,
)

txt="this is one sentence." + tokenizer.eos_token + "this is another sentence." + tokenizer.eos_token + "this is the third sentence." + tokenizer.eos_token

txt_encoded = tokenizer.encode(txt, add_special_tokens=False)
txt_encoded_decoded = tokenizer.decode(txt_encoded)
txt_encoded_decoded_spaces_false = tokenizer.decode(txt_encoded, spaces_between_special_tokens=False)

print(transformers.__version__)
print(tokenizer.__class__)

print(f"INPUT:\n{txt}\n")
print(f"ROUNDTRIP:\n{txt_encoded_decoded}\n")
print(f"ROUNDTRIP w/ spaces_between_special_tokens=F:\n{txt_encoded_decoded}\n")

Output:

You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
4.31.0
<class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>
INPUT:
this is one sentence.</s>this is another sentence.</s>this is the third sentence.</s>

ROUNDTRIP:
 this is one sentence.</s> this is another sentence.</s> this is the third sentence.</s>

ROUNDTRIP w/ spaces_between_special_tokens=F:
 this is one sentence.</s> this is another sentence.</s> this is the third sentence.</s>

Expected behavior

txt == txt_encoded_decoded

I expect the text to be the same as decode(encode(text)); however, a whitespace is added after each special token (</s>). From what I saw in previous issues, spaces_between_special_tokens=False should change that, but it does not: the whitespaces are still there.

What am I missing?

Thank you for your help, and apologies in advance: this issue seems to come up quite often, and I spent quite some time going through issues in this repo, but nothing solved it for me.

@ArthurZucker
Collaborator

Hey! My first suggestion would be not to use the legacy behaviour: set legacy=False when you initialize the tokenizer.
Second, the txt == txt_encoded_decoded assumption is not always true for all tokenizers. In this case, the decoding adds an extra space, maybe because it is based on the previous legacy behaviour. Will investigate.
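
For reference, a minimal sketch of that setup, reusing the reproduction above (the model path is the reporter's local checkout, so adjust it to yours):

from transformers import AutoTokenizer

# Same reproduction as above, but with the legacy behaviour disabled.
tokenizer = AutoTokenizer.from_pretrained(
	"../models/llama-2-7b",
	use_fast=False,
	legacy=False,
)

txt = "this is one sentence." + tokenizer.eos_token + "this is another sentence."
print(tokenizer.decode(tokenizer.encode(txt, add_special_tokens=False)))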

@g588928812
Author

My first suggestion would be to not use the legacy behaviour by setting legacy = False when you initialize the tokenizer.

Thanks! I tried that, though, and it did not change the output.

@ArthurZucker
Collaborator

OK, the same issue exists with the fast version, but the problem is with the encoding, which adds extra spaces between the special tokens... It's a mess, haha

@wlhgtc
Contributor

wlhgtc commented Jul 28, 2023

@ArthurZucker
Sorry, I can't understand when and why we need to set legacy=False. Could you explain?
I ran the code as follows:

    txt = "one more thing" + "<s>" + "traditionally" + "<s>"
    tokenizer1 = LlamaTokenizer.from_pretrained(
        "./resources/models/llama-2-7b-hf", legacy=True, use_fast=False
    )
    tokenizer2 = LlamaTokenizer.from_pretrained(
        "./resources/models/llama-2-7b-hf", legacy=False, use_fast=False
    )

    t1 = tokenizer1.tokenize(txt)
    t2 = tokenizer2.tokenize(txt)

Then I got:

t1:['▁one', '▁more', '▁thing', '<s>', '▁tradition', 'ally', '<s>']
t2:['▁one', '▁more', '▁thing', '<s>', 'tradition', 'ally', '<s>']

A word starting with ▁ usually means the start of a new word (compare ▁more and ally).
Even though we don't add a space before "traditionally", it is still considered a new word.
So, it seems tokenizer2 is the meaningful one?

@ArthurZucker
Collaborator

ArthurZucker commented Jul 28, 2023

No, a word starting with ▁ means that it has a space before it, and thus the token is ▁tradition, while tradition is a different token. If you read the documentation that points to the PR #24565, there is a similar example.
What's important to understand is the concept of added tokens.

Most often, sentencepiece tokenizers have a vocabulary, but some tokens are added afterwards; this happens with T5, for example. In transformers, we do not modify the underlying sentencepiece object, but we still support adding tokens.

Now imagine that thin is part of the sentencepiece vocab, but ▁thin is not. If thin appears at the start of a word like thinking, it will be tokenized as [▁, thin, king], not [▁, thin, ▁king]. The same applies for any tokens that are originally part of the sentencepiece model.

In transformers, all special tokens are essentially added to the vocabulary, so we want to reproduce that behaviour and not add an extra space.
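
To make that concrete, here is a small sketch based on the snippet earlier in this thread (same local model path, so adjust it to yours; the expected tokens are the ones reported above):

from transformers import LlamaTokenizer

txt = "one more thing" + "<s>" + "traditionally" + "<s>"

tokenizer_legacy = LlamaTokenizer.from_pretrained("./resources/models/llama-2-7b-hf", legacy=True, use_fast=False)
tokenizer_fixed = LlamaTokenizer.from_pretrained("./resources/models/llama-2-7b-hf", legacy=False, use_fast=False)

# legacy=True inserts a space after the added <s>, so the next word starts with ▁.
print(tokenizer_legacy.tokenize(txt))  # ['▁one', '▁more', '▁thing', '<s>', '▁tradition', 'ally', '<s>']
# legacy=False does not, so 'tradition' is encoded exactly as it appears after the token.
print(tokenizer_fixed.tokenize(txt))   # ['▁one', '▁more', '▁thing', '<s>', 'tradition', 'ally', '<s>']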

PS: please refrain from asking something pretty much unrelated. If you have a question (not a bug) feel free to post it on the discussion forum

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Sep 2, 2023
@xenova
Contributor

xenova commented Dec 17, 2023

@ArthurZucker this should be reopened, right? As stated in your previous response:

In transformers all special tokens are kind of added to the vocabulary, so we want to reproduce the behaviour and not add extra space.

So, basically, there should not be an added space after special tokens... However, I'm getting the opposite result, with legacy=False being the incorrect one.

from transformers import AutoTokenizer
text = "hello world"

# 1. Legacy tokenizer
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", use_fast=False, legacy=True)
token_ids = tokenizer.encode(text, add_special_tokens=True)
print(f'{token_ids=}')                              # [1, 22172, 3186] (correct)
print(f'{tokenizer.decode(token_ids)=}')            # '<s>hello world' (correct)

# 2. Non-Legacy tokenizer
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", use_fast=False, legacy=False)
token_ids = tokenizer.encode(text, add_special_tokens=True)
print(f'{token_ids=}')                              # [1, 22172, 3186]  (correct)
print(f'{tokenizer.decode(token_ids)=}')            # '<s> hello world' (incorrect)

(this is also different from the other related issues, since those deal with encoding and not decoding)

@xenova xenova reopened this Dec 17, 2023
@ArthurZucker
Collaborator

Yes, until #26678 is merged

@huggingface huggingface deleted a comment from github-actions bot Jan 12, 2024
@ArthurZucker
Collaborator

Just wait a bit!


github-actions bot commented Feb 6, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@xenova
Contributor

xenova commented Feb 6, 2024

Closed by #26678 👍

@xenova xenova closed this as completed Feb 6, 2024
@Butanium

Butanium commented Aug 14, 2024

Why is this closed when the behavior is not the same depending on whether you set use_fast to True or False, @ArthurZucker?

@ArthurZucker
Collaborator

No, the reason it's closed is that this is controlled by a flag, legacy, which can be set to True or False.
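
Setting it explicitly for both tokenizers looks roughly like this (a sketch using the public test tokenizer from the example above; whether the fast tokenizer accepts the legacy kwarg depends on your transformers version, so treat that part as an assumption):

from transformers import AutoTokenizer

repo = "hf-internal-testing/llama-tokenizer"
text = "hello world"

slow = AutoTokenizer.from_pretrained(repo, use_fast=False, legacy=False)
fast = AutoTokenizer.from_pretrained(repo, use_fast=True, legacy=False)

# With the fix merged, both should round-trip without a space after <s>.
print(slow.decode(slow.encode(text, add_special_tokens=True)))
print(fast.decode(fast.encode(text, add_special_tokens=True)))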
