From 304aacac90ea6df8f3bfc2956a0ae6137f690bc0 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 26 Apr 2023 18:29:25 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=F0=9F=9A=A8=20[`Pix2St?= =?UTF-8?q?ruct`]=20Attempts=20to=20fix=20training=20issues=20=F0=9F=9A=A8?= =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=20(#23004)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * multiple fixes - add `add_special_tokens` to `True` by default - remove label smoothing and labels masking * fix test --- src/transformers/models/pix2struct/modeling_pix2struct.py | 5 ++--- src/transformers/models/pix2struct/processing_pix2struct.py | 2 +- tests/models/pix2struct/test_processor_pix2struct.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index c2cd1a0a32..1d2062519f 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1554,10 +1554,9 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): if labels is not None: # move labels to correct device to enable model parallelism labels = labels.to(logits.device) - loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean", label_smoothing=0.1) - masked_labels = labels.masked_fill(labels == self.config.pad_token_id, -100) + loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean") - loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), masked_labels.contiguous().view(-1)) + loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1)) if not return_dict: return tuple( diff --git a/src/transformers/models/pix2struct/processing_pix2struct.py b/src/transformers/models/pix2struct/processing_pix2struct.py index eaa9f0dc42..bc54e14604 100644 --- a/src/transformers/models/pix2struct/processing_pix2struct.py +++ b/src/transformers/models/pix2struct/processing_pix2struct.py @@ -49,7 +49,7 @@ class Pix2StructProcessor(ProcessorMixin): self, images=None, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - add_special_tokens: bool = False, + add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = None, max_length: Optional[int] = None, diff --git a/tests/models/pix2struct/test_processor_pix2struct.py b/tests/models/pix2struct/test_processor_pix2struct.py index 6a7387b6ff..318e6f301f 100644 --- a/tests/models/pix2struct/test_processor_pix2struct.py +++ b/tests/models/pix2struct/test_processor_pix2struct.py @@ -108,7 +108,7 @@ class Pix2StructProcessorTest(unittest.TestCase): encoded_processor = processor(text=input_str) - encoded_tok = tokenizer(input_str, return_token_type_ids=False, add_special_tokens=False) + encoded_tok = tokenizer(input_str, return_token_type_ids=False, add_special_tokens=True) for key in encoded_tok.keys(): self.assertListEqual(encoded_tok[key], encoded_processor[key])