Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LayoutLMForQuestionAnswering model #18407

Merged
merged 20 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/en/model_doc/layoutlm.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ This model was contributed by [liminghao1630](https://huggingface.co/liminghao16

[[autodoc]] LayoutLMForTokenClassification

## LayoutLMForQuestionAnswering

[[autodoc]] LayoutLMForQuestionAnswering

## TFLayoutLMModel

[[autodoc]] TFLayoutLMModel
Expand All @@ -122,3 +126,7 @@ This model was contributed by [liminghao1630](https://huggingface.co/liminghao16
## TFLayoutLMForTokenClassification

[[autodoc]] TFLayoutLMForTokenClassification

## TFLayoutLMForQuestionAnswering

[[autodoc]] TFLayoutLMForQuestionAnswering
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,6 +1305,7 @@
"LayoutLMForMaskedLM",
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMForQuestionAnswering",
"LayoutLMModel",
"LayoutLMPreTrainedModel",
]
Expand Down Expand Up @@ -2337,6 +2338,7 @@
"TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFLayoutLMForMaskedLM",
"TFLayoutLMForSequenceClassification",
"TFLayoutLMForQuestionAnswering",
"TFLayoutLMForTokenClassification",
"TFLayoutLMMainLayer",
"TFLayoutLMModel",
Expand Down Expand Up @@ -3945,6 +3947,7 @@
from .models.layoutlm import (
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMForMaskedLM,
LayoutLMForQuestionAnswering,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
Expand Down Expand Up @@ -4583,6 +4586,7 @@
from .modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM,
TFLayoutLMForQuestionAnswering,
TFLayoutLMForSequenceClassification,
TFLayoutLMForTokenClassification,
TFLayoutLMMainLayer,
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/layoutlm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"LayoutLMForMaskedLM",
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMForQuestionAnswering",
"LayoutLMModel",
"LayoutLMPreTrainedModel",
]
Expand All @@ -66,6 +67,7 @@
"TFLayoutLMForMaskedLM",
"TFLayoutLMForSequenceClassification",
"TFLayoutLMForTokenClassification",
"TFLayoutLMForQuestionAnswering",
"TFLayoutLMMainLayer",
"TFLayoutLMModel",
"TFLayoutLMPreTrainedModel",
Expand Down Expand Up @@ -93,6 +95,7 @@
from .modeling_layoutlm import (
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMForMaskedLM,
LayoutLMForQuestionAnswering,
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
Expand All @@ -107,6 +110,7 @@
from .modeling_tf_layoutlm import (
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLayoutLMForMaskedLM,
TFLayoutLMForQuestionAnswering,
TFLayoutLMForSequenceClassification,
TFLayoutLMForTokenClassification,
TFLayoutLMMainLayer,
Expand Down
162 changes: 153 additions & 9 deletions src/transformers/models/layoutlm/modeling_layoutlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
Expand All @@ -40,7 +41,6 @@
logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LayoutLMConfig"
_TOKENIZER_FOR_DOC = "LayoutLMTokenizer"
_CHECKPOINT_FOR_DOC = "microsoft/layoutlm-base-uncased"

LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
Expand Down Expand Up @@ -749,10 +749,10 @@ def forward(
Examples:

```python
>>> from transformers import LayoutLMTokenizer, LayoutLMModel
>>> from transformers import AutoTokenizer, LayoutLMModel
>>> import torch

>>> tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = LayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")

>>> words = ["Hello", "world"]
Expand Down Expand Up @@ -896,10 +896,10 @@ def forward(
Examples:

```python
>>> from transformers import LayoutLMTokenizer, LayoutLMForMaskedLM
>>> from transformers import AutoTokenizer, LayoutLMForMaskedLM
>>> import torch

>>> tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = LayoutLMForMaskedLM.from_pretrained("microsoft/layoutlm-base-uncased")

>>> words = ["Hello", "[MASK]"]
Expand Down Expand Up @@ -1018,10 +1018,10 @@ def forward(
Examples:

```python
>>> from transformers import LayoutLMTokenizer, LayoutLMForSequenceClassification
>>> from transformers import AutoTokenizer, LayoutLMForSequenceClassification
>>> import torch

>>> tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = LayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased")

>>> words = ["Hello", "world"]
Expand Down Expand Up @@ -1153,10 +1153,10 @@ def forward(
Examples:

```python
>>> from transformers import LayoutLMTokenizer, LayoutLMForTokenClassification
>>> from transformers import AutoTokenizer, LayoutLMForTokenClassification
>>> import torch

>>> tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
>>> model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased")

>>> words = ["Hello", "world"]
Expand Down Expand Up @@ -1222,3 +1222,147 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


@add_start_docstrings(
"""
LayoutLM Model with a span classification head on top for extractive question-answering tasks such as
[DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the final hidden-states output to compute `span
start logits` and `span end logits`).
""",
LAYOUTLM_START_DOCSTRING,
)
class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
def __init__(self, config, has_visual_segment_embedding=True):
super().__init__(config)
self.num_labels = config.num_labels

self.layoutlm = LayoutLMModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.layoutlm.embeddings.word_embeddings

@replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
bbox: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.

Returns:

Example:

In the example below, we prepare a question + context pair for the LayoutLM model. It will give us a prediction
of what it thinks the answer is (the span of the answer within the texts parsed from the image).

```python
>>> from transformers import AutoTokenizer, LayoutLMForQuestionAnswering
>>> from datasets import load_dataset
>>> import torch

>>> tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
>>> model = LayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa")

>>> dataset = load_dataset("nielsr/funsd", split="train")
>>> example = dataset[0]
>>> question = "what's his name?"
>>> words = example["words"]
>>> boxes = example["bboxes"]

>>> encoding = tokenizer(
... question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="pt"
... )
>>> bbox = []
>>> for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)):
... if s == 1:
... bbox.append(boxes[w])
... elif i == tokenizer.sep_token_id:
... bbox.append([1000] * 4)
... else:
... bbox.append([0] * 4)
>>> encoding["bbox"] = torch.tensor([bbox])

>>> word_ids = encoding.word_ids(0)
>>> outputs = model(**encoding)
>>> loss = outputs.loss
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm the code examples work as expected, it would be great to add LayoutLM (v1) to the doc tests. Details here: https://github.com/huggingface/transformers/tree/main/docs#testing-documentation-examples

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

>>> start, end = word_ids[start_scores.argmax(-1)], word_ids[end_scores.argmax(-1)]
>>> print(" ".join(words[start : end + 1]))
M. Hamann P. Harper, P. Martinez
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great example! Would like to see something similar for the other models (if we don't have any fine-tuned ones, then feel free to just verify an expected shape)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we address this separately? I'm not extremely familiar with all of the different models, so I'd prefer to separate it from this particular effort.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, makes sense.

```"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0]

logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()

total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2

if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output

return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
Loading