LayoutLMv2Processor: ensure 1-to-1 mapping between images and samples in case of overflowing tokens (#17092)
* Add `get_overflowing_images` function to ensure a 1-to-1 mapping between samples and images in `LayoutLMv2Processor`.
* Run `make style`.
* Add a test for overflowing tokens, change the assert to a `ValueError`, and avoid unrelated formatting changes.
* Change line length by passing `--preview` to black.
This commit is contained in:
@@ -86,10 +86,12 @@ class LayoutLMv2Processor(ProcessorMixin):
|
|||||||
|
|
||||||
if self.feature_extractor.apply_ocr and (word_labels is not None):
|
if self.feature_extractor.apply_ocr and (word_labels is not None):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You cannot provide word labels "
|
"You cannot provide word labels if you initialized the feature extractor with apply_ocr set to True."
|
||||||
"if you initialized the feature extractor with apply_ocr set to True."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if return_overflowing_tokens is True and return_offsets_mapping is False:
|
||||||
|
raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.")
|
||||||
|
|
||||||
# first, apply the feature extractor
|
# first, apply the feature extractor
|
||||||
features = self.feature_extractor(images=images, return_tensors=return_tensors)
|
features = self.feature_extractor(images=images, return_tensors=return_tensors)
|
||||||
|
|
||||||
@@ -122,6 +124,23 @@ class LayoutLMv2Processor(ProcessorMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# add pixel values
|
# add pixel values
|
||||||
encoded_inputs["image"] = features.pop("pixel_values")
|
images = features.pop("pixel_values")
|
||||||
|
if return_overflowing_tokens is True:
|
||||||
|
images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"])
|
||||||
|
encoded_inputs["image"] = images
|
||||||
|
|
||||||
return encoded_inputs
|
return encoded_inputs
|
||||||
|
|
||||||
|
def get_overflowing_images(self, images, overflow_to_sample_mapping):
    """
    Build the list of images that lines up 1-to-1 with the encoded sequences.

    In case there's an overflow, a single original sample can yield several `input_ids` sequences. For every
    entry of `overflow_to_sample_mapping` this returns the image of the sample that entry points to, so each
    encoded sequence is paired with its corresponding image.

    Args:
        images: Indexable collection of images, one per original sample.
        overflow_to_sample_mapping: Iterable of indices into `images`, one per encoded sequence.

    Returns:
        `list`: `images[sample_idx]` for each `sample_idx` in `overflow_to_sample_mapping`, in order.

    Raises:
        Whatever `images[sample_idx]` raises for an invalid index (e.g. `IndexError` for a list).
    """
    # Exactly one image is emitted per mapping entry, so the result has the same length as the mapping by
    # construction; a separate length comparison could never fail and is therefore not performed.
    return [images[sample_idx] for sample_idx in overflow_to_sample_mapping]
|
||||||
|
|||||||
@@ -133,6 +133,39 @@ class LayoutLMv2ProcessorTest(unittest.TestCase):
|
|||||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
||||||
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
|
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
|
||||||
|
|
||||||
|
@slow
def test_overflowing_tokens(self):
    """
    With overflowing tokens, sequences that are too long are broken down into multiple sequences; check that
    the processor still returns exactly one image per `input_ids` sequence.
    """
    from datasets import load_dataset

    # Fixtures: the FUNSD dataset and the processor revision loaded with `revision="no_ocr"`.
    funsd = load_dataset("nielsr/funsd")
    processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")

    # Encode the whole train split in a single call, asking for overflowing tokens.
    train_split = funsd["train"]
    rgb_images = []
    for image_path in train_split["image_path"]:
        rgb_images.append(Image.open(image_path).convert("RGB"))

    processor_kwargs = {
        "boxes": train_split["bboxes"],
        "word_labels": train_split["ner_tags"],
        "padding": "max_length",
        "truncation": True,
        "return_overflowing_tokens": True,
        "stride": 50,
        "return_offsets_mapping": True,
        "return_tensors": "pt",
    }
    train_data = processor(rgb_images, train_split["words"], **processor_kwargs)

    # One image per encoded sequence.
    num_images = len(train_data["image"])
    num_sequences = len(train_data["input_ids"])
    self.assertEqual(num_images, num_sequences)
|
||||||
|
|
||||||
|
|
||||||
# different use cases tests
|
# different use cases tests
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user