From 9336086ab5d232cccd9512333518cf4299528882 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Fri, 28 Aug 2020 11:15:17 -0400 Subject: [PATCH] prepare_seq2seq_batch makes labels/ decoder_input_ids made later. (#6654) * broken test * batch parity * tests pass * boom boom * boom boom * split out bart tokenizer tests * fix tests * boom boom * Fixed dataset bug * Fix marian * Undo extra * Get marian working * Fix t5 tok tests * Test passing * Cleanup * better assert msg * require torch * Fix mbart tests * undo extra decoder_attn_mask change * Fix import * pegasus tokenizer can ignore src_lang kwargs * unused kwarg test cov * boom boom * add todo for pegasus issue * cover one word translation edge case * Cleanup * doc --- examples/seq2seq/README.md | 6 +- examples/seq2seq/distillation.py | 118 +++++++++++------- examples/seq2seq/finetune.py | 44 ++++--- examples/seq2seq/run_eval.py | 2 + examples/seq2seq/test_seq2seq_examples.py | 38 +++--- examples/seq2seq/utils.py | 51 ++++---- src/transformers/modeling_bart.py | 3 + src/transformers/modeling_t5.py | 2 +- src/transformers/tokenization_bart.py | 19 ++- src/transformers/tokenization_marian.py | 4 +- src/transformers/tokenization_mbart.py | 65 +++++----- src/transformers/tokenization_pegasus.py | 9 +- src/transformers/tokenization_t5.py | 9 +- tests/test_modeling_bart.py | 82 +----------- tests/test_tokenization_bart.py | 145 ++++++++++++++++++++++ tests/test_tokenization_common.py | 11 +- tests/test_tokenization_mbart.py | 88 ++++++++----- tests/test_tokenization_pegasus.py | 7 +- tests/test_tokenization_roberta.py | 8 +- tests/test_tokenization_t5.py | 8 +- 20 files changed, 429 insertions(+), 290 deletions(-) create mode 100644 tests/test_tokenization_bart.py diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index 9c94b0c2c8..db047fe956 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -71,8 +71,8 @@ Summarization Tips: (It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods). **Update 2018-07-18** -Datasets: `Seq2SeqDataset` should be used for all tokenizers without a `prepare_seq2seq_batch` method. For those who do (like Marian, MBart), `TranslationDataset` should be used.** -A new dataset is needed to support multilingual tasks. +Datasets: `LegacySeq2SeqDataset` will be used for all tokenizers without a `prepare_seq2seq_batch` method. Otherwise, `Seq2SeqDataset` will be used. +Future work/help wanted: A new dataset to support multilingual tasks. ### Command Line Options @@ -106,7 +106,7 @@ The following command should work on a 16GB GPU: --train_batch_size=1 \ --eval_batch_size=1 \ --output_dir=xsum_results \ - --num_train_epochs 1 \ + --num_train_epochs 6 \ --model_name_or_path facebook/bart-large ``` diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index 67e695ef99..262fae182f 100644 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -1,6 +1,7 @@ import argparse import gc import os +import warnings from pathlib import Path from typing import List @@ -11,6 +12,7 @@ from torch.nn import functional as F from lightning_base import generic_train from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration +from transformers.modeling_bart import shift_tokens_right try: @@ -22,6 +24,7 @@ try: assert_all_frozen, calculate_bleu, freeze_params, + label_smoothed_nll_loss, pickle_load, use_task_specific_params, ) @@ -34,12 +37,15 @@ except ImportError: assert_all_frozen, calculate_bleu, freeze_params, + label_smoothed_nll_loss, pickle_load, use_task_specific_params, ) class BartSummarizationDistiller(SummarizationModule): + """Supports Bart, Pegasus and other models that inherit from Bart.""" + loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"] def __init__(self, hparams): @@ -160,22 +166,32 @@ class BartSummarizationDistiller(SummarizationModule): def _step(self, batch): # assert is_frozen(self.teacher) pad_token_id = self.tokenizer.pad_token_id - input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"] - decoder_input_ids = y[:, :-1].contiguous() - labels = y[:, 1:].clone() - labels[y[:, 1:] == pad_token_id] = -100 + input_ids, src_mask, tgt_ids = batch["input_ids"], batch["attention_mask"], batch["labels"] + decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id) # noinspection PyCallingNonCallable - sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self( + lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self( input_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, - labels=labels, output_hidden_states=True, output_attentions=False, - ) + use_cache=False, + ) # TODO(@sshleifer): return_dict=True cleanup + + # Same cross entropy vs. label smoothing logic as finetune.py + assert lm_logits.shape[-1] == self.model.config.vocab_size + if self.hparams.label_smoothing == 0: + # Same behavior as modeling_bart.py, besides ignoring pad_token_id + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) + student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1)) + else: + lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1) + student_lm_loss, _ = label_smoothed_nll_loss( + lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id + ) def zero_tensor(): - return torch.tensor(0.0).type_as(sloss) + return torch.tensor(0.0).type_as(student_lm_loss) loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor() if self.different_encoder: @@ -199,29 +215,26 @@ class BartSummarizationDistiller(SummarizationModule): attention_mask=src_mask, encoder_outputs=teacher_enc_outputs, decoder_input_ids=decoder_input_ids, - lm_labels=labels, + lm_labels=tgt_ids, output_hidden_states=True, ) dec_mask = decoder_input_ids.ne(pad_token_id) - loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits) + loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits) if self.alpha_hid > 0: hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy) blended_loss = ( self.alpha_ce * loss_ce - + self.alpha_mlm * sloss + + self.alpha_mlm * student_lm_loss + self.hparams.alpha_encoder_loss * loss_encoder + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec) ) - return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec + return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches): - assert not isinstance( - hidden_states, torch.Tensor - ), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}" - assert not isinstance( - hidden_states_T, torch.Tensor - ), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}" + msg = "expected list or tuple for hidden_states, got tensor of shape: " + assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}" + assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}" mask = attention_mask.to(hidden_states[0]) valid_count = mask.sum() * hidden_states[0].size(-1) hidden_losses = [ @@ -233,7 +246,7 @@ class BartSummarizationDistiller(SummarizationModule): def add_distill_args(parser): - parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str) + parser.add_argument("--teacher", type=str) parser.add_argument("--alpha_ce", default=0.8, type=float) parser.add_argument("--alpha_mlm", default=0.2, type=float) parser.add_argument("--alpha_encoder_loss", default=0.0, type=float) @@ -245,8 +258,9 @@ def add_distill_args(parser): class BartTranslationDistiller(BartSummarizationDistiller): + """Supports Mbart, Marian, other models that inherit from Bart.""" + mode = "translation" - loss_names = ["loss"] metric_names = ["bleu"] val_metric = "bleu" @@ -368,7 +382,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller): attention_mask=source_mask, encoder_outputs=teacher_enc_outputs, decoder_input_ids=decoder_input_ids, - lm_labels=labels, + labels=labels, output_hidden_states=True, use_cache=False, ) @@ -402,6 +416,7 @@ def create_module(args): def evaluate_checkpoint(ckpt_path: Path, dest_dir=None): + # TODO(SS): DELETE? exp_dir = ckpt_path.parent if dest_dir is None: dest_dir = exp_dir @@ -424,33 +439,40 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None): trainer.test(model) -def get_layers_to_copy(n_to_get, tot): - all_layers = list(range(tot)) - if tot == 12: # Alternating for special cases - layers_to_copy = { # maps num layers in student -> which teacher layers to copy - 1: [0], - 2: [0, 6], - 3: [0, 6, 11], - 4: [0, 4, 8, 11], - 6: [0, 2, 4, 7, 9, 11], - 9: [0, 1, 2, 4, 5, 7, 9, 10, 11], - 12: all_layers, - } - return layers_to_copy[n_to_get] - elif tot == 16: - layers_to_copy = { # maps num layers in student -> which teacher layers to copy - 1: [0], - 2: [0, 8], - 3: [0, 8, 15], - 4: [0, 5, 10, 15], - 6: [0, 3, 6, 9, 12, 15], - 8: [0, 2, 4, 6, 8, 10, 12, 15], - 9: [0, 1, 3, 5, 7, 9, 11, 13, 15], - 16: all_layers, - } - return layers_to_copy[n_to_get] - else: - return all_layers[:n_to_get] # TODO: better version on theseus-bart branch +LAYERS_TO_COPY = { + # maps num layers in student -> which teacher layers to copy. + # 12: bart, 16: pegasus, 6: marian/Helsinki-NLP + 12: { + 1: [0], + 2: [0, 6], + 3: [0, 6, 11], + 4: [0, 4, 8, 11], + 6: [0, 2, 4, 7, 9, 11], + 9: [0, 1, 2, 4, 5, 7, 9, 10, 11], + 12: list(range(12)), + }, + 16: { # maps num layers in student -> which teacher layers to copy + 1: [0], + 2: [0, 8], + 3: [0, 8, 15], + 4: [0, 5, 10, 15], + 6: [0, 3, 6, 9, 12, 15], + 8: [0, 2, 4, 6, 8, 10, 12, 15], + 9: [0, 1, 3, 5, 7, 9, 11, 13, 15], + 16: list(range(16)), + }, + 6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))}, +} + + +def get_layers_to_copy(n_student, n_teacher): + try: + return LAYERS_TO_COPY[n_teacher][n_student] + except KeyError: + warnings.warn( + f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}" + ) + return list(range(n_student)) def distill_main(args): diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 539b296142..90591e1b0c 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -13,15 +13,16 @@ import torch from torch.utils.data import DataLoader from lightning_base import BaseTransformer, add_generic_args, generic_train -from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration +from transformers import MBartTokenizer, T5ForConditionalGeneration +from transformers.modeling_bart import shift_tokens_right try: from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback from .utils import ( ROUGE_KEYS, + LegacySeq2SeqDataset, Seq2SeqDataset, - TranslationDataset, assert_all_frozen, calculate_bleu, calculate_rouge, @@ -39,8 +40,8 @@ except ImportError: from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback from utils import ( ROUGE_KEYS, + LegacySeq2SeqDataset, Seq2SeqDataset, - TranslationDataset, assert_all_frozen, calculate_bleu, calculate_rouge, @@ -102,14 +103,13 @@ class SummarizationModule(BaseTransformer): self.hparams.git_sha = get_git_info()["repo_sha"] self.num_workers = hparams.num_workers - self.decoder_start_token_id = None + self.decoder_start_token_id = None # default to config if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer): self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang] self.model.config.decoder_start_token_id = self.decoder_start_token_id - if isinstance(self.tokenizer, MBartTokenizer) or isinstance(self.tokenizer, MarianTokenizer): - self.dataset_class = TranslationDataset - else: - self.dataset_class = Seq2SeqDataset + self.dataset_class = ( + Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset + ) def freeze_embeds(self): """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" @@ -134,27 +134,25 @@ class SummarizationModule(BaseTransformer): def _step(self, batch: dict) -> Tuple: pad_token_id = self.tokenizer.pad_token_id - source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"] - + src_ids, src_mask = batch["input_ids"], batch["attention_mask"] + tgt_ids = batch["labels"] if isinstance(self.model, T5ForConditionalGeneration): - decoder_input_ids = self.model._shift_right(target_ids) - lm_labels = target_ids + decoder_input_ids = self.model._shift_right(tgt_ids) else: - decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line? - lm_labels = target_ids[:, 1:].clone() # why clone? - - outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False) + decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id) + outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False) + lm_logits = outputs[0] if self.hparams.label_smoothing == 0: - # Same behavior as modeling_bart.py + # Same behavior as modeling_bart.py, besides ignoring pad_token_id loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) - lm_logits = outputs[0] + assert lm_logits.shape[-1] == self.model.config.vocab_size - loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1)) + loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1)) else: - lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1) + lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1) loss, nll_loss = label_smoothed_nll_loss( - lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id + lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id ) return (loss,) @@ -167,7 +165,7 @@ class SummarizationModule(BaseTransformer): logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} # tokens per batch - logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["decoder_input_ids"].ne(self.pad).sum() + logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum() return {"loss": loss_tensors[0], "log": logs} def validation_step(self, batch, batch_idx) -> Dict: @@ -204,7 +202,7 @@ class SummarizationModule(BaseTransformer): ) gen_time = (time.time() - t0) / batch["input_ids"].shape[0] preds: List[str] = self.ids_to_clean_text(generated_ids) - target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"]) + target: List[str] = self.ids_to_clean_text(batch["labels"]) loss_tensors = self._step(batch) base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} rouge: Dict = self.calc_generative_metrics(preds, target) diff --git a/examples/seq2seq/run_eval.py b/examples/seq2seq/run_eval.py index c83b17608f..ba5150d1d5 100644 --- a/examples/seq2seq/run_eval.py +++ b/examples/seq2seq/run_eval.py @@ -132,4 +132,6 @@ def run_generate(): if __name__ == "__main__": + # Usage for MT: + # python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target --score_path $save_dir/test_bleu.json --task translation $@ run_generate() diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index f853557f18..3747c0ac7f 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -10,18 +10,18 @@ from unittest.mock import patch import pytest import pytorch_lightning as pl import torch -from pytest import param from torch.utils.data import DataLoader import lightning_base from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from transformers.modeling_bart import shift_tokens_right from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu from .distillation import distill_main, evaluate_checkpoint from .finetune import SummarizationModule, main from .pack_dataset import pack_data_dir from .run_eval import generate_summaries_or_translations, run_generate -from .utils import Seq2SeqDataset, TranslationDataset, label_smoothed_nll_loss, lmap, load_json +from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json logging.basicConfig(level=logging.DEBUG) @@ -452,18 +452,27 @@ def test_pack_dataset(): assert orig_paths == new_paths -@pytest.mark.parametrize(["tok_name"], [pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]) -def test_mbart_dataset_truncation(tok_name): +@pytest.mark.parametrize( + ["tok_name"], + [ + pytest.param(MBART_TINY), + pytest.param(MARIAN_TINY), + pytest.param(T5_TINY), + pytest.param(BART_TINY), + pytest.param("google/pegasus-xsum"), + ], +) +def test_seq2seq_dataset_truncation(tok_name): tokenizer = AutoTokenizer.from_pretrained(tok_name) tmp_dir = make_test_data_dir() max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES) max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES) max_src_len = 4 max_tgt_len = 8 - assert max_len_target > max_src_len # Truncated - assert max_len_source > max_src_len - src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON - train_dataset = TranslationDataset( + assert max_len_target > max_src_len # Will be truncated + assert max_len_source > max_src_len # Will be truncated + src_lang, tgt_lang = "ro_RO", "de_DE" # ignored for all but mbart, but never causes error. + train_dataset = Seq2SeqDataset( tokenizer, data_dir=tmp_dir, type_path="train", @@ -479,10 +488,11 @@ def test_mbart_dataset_truncation(tok_name): # show that articles were trimmed. assert batch["input_ids"].shape[1] == max_src_len # show that targets are the same len - assert batch["decoder_input_ids"].shape[1] == max_tgt_len - if tok_name == MARIAN_TINY: + assert batch["labels"].shape[1] == max_tgt_len + if tok_name != MBART_TINY: continue # check language codes in correct place + batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id) assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang] assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id @@ -491,14 +501,14 @@ def test_mbart_dataset_truncation(tok_name): break # No need to test every batch -@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)]) -def test_summarization_dataset_truncation(tok): +@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")]) +def test_legacy_dataset_truncation(tok): tokenizer = AutoTokenizer.from_pretrained(tok) tmp_dir = make_test_data_dir() max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES) max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES) trunc_target = 4 - train_dataset = Seq2SeqDataset( + train_dataset = LegacySeq2SeqDataset( tokenizer, data_dir=tmp_dir, type_path="train", @@ -512,6 +522,6 @@ def test_summarization_dataset_truncation(tok): assert batch["input_ids"].shape[1] == max_len_source assert 20 >= batch["input_ids"].shape[1] # trimmed significantly # show that targets were truncated - assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated + assert batch["labels"].shape[1] == trunc_target # Truncated assert max_len_target > trunc_target # Truncated break # No need to test every batch diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 48375c6854..2cee416574 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -3,7 +3,6 @@ import json import linecache import os import pickle -import warnings from logging import getLogger from pathlib import Path from typing import Callable, Dict, Iterable, List @@ -41,6 +40,7 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): + """Only used by LegacyDataset""" extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} return tokenizer( [line], @@ -75,7 +75,7 @@ def trim_batch( return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) -class Seq2SeqDataset(Dataset): +class AbstractSeq2SeqDataset(Dataset): def __init__( self, tokenizer, @@ -102,11 +102,28 @@ class Seq2SeqDataset(Dataset): self.pad_token_id = self.tokenizer.pad_token_id self.src_lang = src_lang self.tgt_lang = tgt_lang + self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer) def __len__(self): return len(self.src_lens) + @staticmethod + def get_char_lens(data_file): + return [len(x) for x in Path(data_file).open().readlines()] + + def make_sortish_sampler(self, batch_size): + return SortishSampler(self.src_lens, batch_size) + + def __getitem__(self, item): + raise NotImplementedError("You must implement this") + + def collate_fn(self, batch): + raise NotImplementedError("You must implement this") + + +class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): def __getitem__(self, index) -> Dict[str, torch.Tensor]: + """Call tokenizer on src and tgt_lines""" index = index + 1 # linecache starts at 1 source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") @@ -121,42 +138,27 @@ class Seq2SeqDataset(Dataset): return { "input_ids": source_ids, "attention_mask": src_mask, - "decoder_input_ids": target_ids, + "labels": target_ids, } - @staticmethod - def get_char_lens(data_file): - return [len(x) for x in Path(data_file).open().readlines()] - def collate_fn(self, batch) -> Dict[str, torch.Tensor]: input_ids = torch.stack([x["input_ids"] for x in batch]) masks = torch.stack([x["attention_mask"] for x in batch]) - target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) + target_ids = torch.stack([x["labels"] for x in batch]) pad_token_id = self.pad_token_id y = trim_batch(target_ids, pad_token_id) source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) batch = { "input_ids": source_ids, "attention_mask": source_mask, - "decoder_input_ids": y, + "labels": y, } return batch - def make_sortish_sampler(self, batch_size): - return SortishSampler(self.src_lens, batch_size) - -class TranslationDataset(Seq2SeqDataset): +class Seq2SeqDataset(AbstractSeq2SeqDataset): """A dataset that calls prepare_seq2seq_batch.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.max_source_length != self.max_target_length: - warnings.warn( - f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. " - f"Imbalanced sequence lengths may be undesired for translation tasks" - ) - def __getitem__(self, index) -> Dict[str, str]: index = index + 1 # linecache starts at 1 source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") @@ -169,6 +171,7 @@ class TranslationDataset(Seq2SeqDataset): } def collate_fn(self, batch) -> Dict[str, torch.Tensor]: + """Call prepare_seq2seq_batch.""" batch_encoding = self.tokenizer.prepare_seq2seq_batch( [x["src_texts"] for x in batch], src_lang=self.src_lang, @@ -176,6 +179,8 @@ class TranslationDataset(Seq2SeqDataset): tgt_lang=self.tgt_lang, max_length=self.max_source_length, max_target_length=self.max_target_length, + return_tensors="pt", + add_prefix_space=self.add_prefix_space, ) return batch_encoding.data @@ -276,7 +281,11 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()} +# Utilities for freezing parameters and checking whether they are frozen + + def freeze_params(model: nn.Module): + """Set requires_grad=False for each of model.parameters()""" for par in model.parameters(): par.requires_grad = False diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 6220ceed25..b4ee37f55e 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -151,6 +151,9 @@ def _prepare_bart_decoder_inputs( decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id) else: decoder_padding_mask = invert_mask(decoder_padding_mask) + if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1: + # never mask leading token, even if it is pad + decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1] causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to( dtype=causal_mask_dtype, device=decoder_input_ids.device ) diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 6e5d3c4c83..a8ae72d0b2 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -636,7 +636,7 @@ class T5PreTrainedModel(PreTrainedModel): # replace possible -100 values in labels by `pad_token_id` shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) - assert torch.all(shifted_input_ids >= 0).item(), "Verify that `labels` has only positive values and -100" + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" return shifted_input_ids diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py index 8ee85f7fac..030917c3c3 100644 --- a/src/transformers/tokenization_bart.py +++ b/src/transformers/tokenization_bart.py @@ -33,6 +33,7 @@ _all_bart_models = [ "facebook/bart-large-cnn", "facebook/bart-large-xsum", "yjernite/bart_eli5", + # This is not exhaustive: see https://huggingface.co/models?filter=bart ] @@ -117,6 +118,8 @@ class BartTokenizer(RobertaTokenizer): The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. """ + kwargs.pop("src_lang", None) + kwargs.pop("tgt_lang", None) if max_length is None: max_length = self.model_max_length model_inputs: BatchEncoding = self( @@ -133,7 +136,7 @@ class BartTokenizer(RobertaTokenizer): # Process tgt_texts if max_target_length is None: max_target_length = max_length - decoder_inputs: BatchEncoding = self( + labels = self( tgt_texts, add_special_tokens=True, return_tensors=return_tensors, @@ -141,10 +144,8 @@ class BartTokenizer(RobertaTokenizer): max_length=max_target_length, truncation=truncation, **kwargs, - ) - for k, v in decoder_inputs.items(): - model_inputs[f"decoder_{k}"] = v - + )["input_ids"] + model_inputs["labels"] = labels return model_inputs @@ -245,7 +246,7 @@ class BartTokenizerFast(RobertaTokenizerFast): # Process tgt_texts if max_target_length is None: max_target_length = max_length - decoder_inputs: BatchEncoding = self( + labels = self( tgt_texts, add_special_tokens=True, return_tensors=return_tensors, @@ -253,8 +254,6 @@ class BartTokenizerFast(RobertaTokenizerFast): max_length=max_target_length, truncation=truncation, **kwargs, - ) - for k, v in decoder_inputs.items(): - model_inputs[f"decoder_{k}"] = v - + )["input_ids"] + model_inputs["labels"] = labels return model_inputs diff --git a/src/transformers/tokenization_marian.py b/src/transformers/tokenization_marian.py index f883d41288..3f06092b53 100644 --- a/src/transformers/tokenization_marian.py +++ b/src/transformers/tokenization_marian.py @@ -160,9 +160,7 @@ class MarianTokenizer(PreTrainedTokenizer): tokenizer_kwargs["max_length"] = max_target_length self.current_spm = self.spm_target - decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs) - for k, v in decoder_inputs.items(): - model_inputs[f"decoder_{k}"] = v + model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"] self.current_spm = self.spm_source return model_inputs diff --git a/src/transformers/tokenization_mbart.py b/src/transformers/tokenization_mbart.py index 9a4dffb725..d575c130ad 100644 --- a/src/transformers/tokenization_mbart.py +++ b/src/transformers/tokenization_mbart.py @@ -98,32 +98,6 @@ class MBartTokenizer(XLMRobertaTokenizer): self._additional_special_tokens = list(self.lang_code_to_id.keys()) self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX")) - def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Build model inputs from a sequence or a pair of sequence for sequence classification tasks - by concatenating and adding special tokens. The special tokens depend on calling set_lang. - An MBART sequence has the following format, where ``X`` represents the sequence: - - ``input_ids`` (for encoder) ``X [eos, src_lang_code]`` - - ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]`` - BOS is never used. - Pairs of sequences are not the expected use case, but they will be handled without a separator. - - Args: - token_ids_0 (:obj:`List[int]`): - List of IDs to which the special tokens will be added - token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): - Optional second list of IDs for sequence pairs. - - Returns: - :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. - """ - if token_ids_1 is None: - return self.prefix_tokens + token_ids_0 + self.suffix_tokens - # We don't expect to process pairs, but leave the pair logic for API consistency - return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens - def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: @@ -156,6 +130,32 @@ class MBartTokenizer(XLMRobertaTokenizer): return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. The special tokens depend on calling set_lang. + An MBART sequence has the following format, where ``X`` represents the sequence: + - ``input_ids`` (for encoder) ``X [eos, src_lang_code]`` + - ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]`` + BOS is never used. + Pairs of sequences are not the expected use case, but they will be handled without a separator. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) def prepare_seq2seq_batch( self, @@ -251,7 +251,8 @@ class MBartTokenizer(XLMRobertaTokenizer): if max_target_length is None: max_target_length = max_length self.set_tgt_lang_special_tokens(tgt_lang) - decoder_inputs: BatchEncoding = self( + + labels = self( tgt_texts, add_special_tokens=True, return_tensors=return_tensors, @@ -259,10 +260,8 @@ class MBartTokenizer(XLMRobertaTokenizer): max_length=max_target_length, truncation=True, **kwargs, - ) - for k, v in decoder_inputs.items(): - model_inputs[f"decoder_{k}"] = v - + )["input_ids"] + model_inputs["labels"] = labels self.set_src_lang_special_tokens(src_lang) # sets to src_lang return model_inputs @@ -275,5 +274,5 @@ class MBartTokenizer(XLMRobertaTokenizer): def set_tgt_lang_special_tokens(self, lang: str) -> None: """Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos].""" self.cur_lang_code = self.lang_code_to_id[lang] - self.prefix_tokens = [self.cur_lang_code] - self.suffix_tokens = [self.eos_token_id] + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] diff --git a/src/transformers/tokenization_pegasus.py b/src/transformers/tokenization_pegasus.py index e553ad456d..369fcd4673 100644 --- a/src/transformers/tokenization_pegasus.py +++ b/src/transformers/tokenization_pegasus.py @@ -114,6 +114,7 @@ class PegasusTokenizer(ReformerTokenizer): return_tensors: str = "pt", truncation=True, padding="longest", + **unused, ) -> BatchEncoding: """ Prepare model inputs for summarization or translation. @@ -133,7 +134,9 @@ class PegasusTokenizer(ReformerTokenizer): return model_inputs if max_target_length is not None: tokenizer_kwargs["max_length"] = max_target_length - decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs) - for k, v in decoder_inputs.items(): - model_inputs[f"decoder_{k}"] = v + # TODO(@sshleifer): maybe tgt_texts = [self.pad_token + t for t in tgt_texts] # add decoder_start_token_id + labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"] + model_inputs["labels"] = labels + # for k, v in decoder_inputs.items(): + # model_inputs[f"decoder_{k}"] = v return model_inputs diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index ce686612c0..0c9966a48e 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -346,7 +346,7 @@ class T5Tokenizer(PreTrainedTokenizer): if max_length is None: max_length = self.max_len self.prefix_tokens = [] - model_inputs: BatchEncoding = self( + model_inputs = self( src_texts, add_special_tokens=True, return_tensors=return_tensors, @@ -362,7 +362,7 @@ class T5Tokenizer(PreTrainedTokenizer): max_target_length = max_length # set prefix_tokens for target text self.prefix_tokens = [self.pad_token_id] - decoder_inputs: BatchEncoding = self( + labels_and_decoder_mask = self( tgt_texts, add_special_tokens=True, return_tensors=return_tensors, @@ -371,8 +371,7 @@ class T5Tokenizer(PreTrainedTokenizer): truncation=truncation, **kwargs, ) - for k, v in decoder_inputs.items(): - model_inputs[f"decoder_{k}"] = v - + model_inputs["labels"] = labels_and_decoder_mask["input_ids"] + model_inputs["decoder_attention_mask"] = labels_and_decoder_mask["attention_mask"] self.prefix_tokens = [] return model_inputs diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 74306556a6..816796a911 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -18,7 +18,7 @@ import unittest import timeout_decorator # noqa -from transformers import BatchEncoding, is_torch_available +from transformers import is_torch_available from transformers.file_utils import cached_property from transformers.testing_utils import require_torch, slow, torch_device @@ -496,7 +496,7 @@ class BartModelIntegrationTests(unittest.TestCase): def test_xsum_summarization_same_as_fairseq(self): model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum").to(torch_device) self.assertFalse(model.config.is_valid_mbart()) - tok = BartTokenizer.from_pretrained("facebook/bart-large") + tok = self.default_tokenizer EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state." dct = tok.batch_encode_plus( @@ -585,84 +585,6 @@ class BartModelIntegrationTests(unittest.TestCase): # TODO(SS): run fairseq again with num_beams=2, min_len=20. # TODO(SS): add test case that hits max_length - def test_prepare_seq2seq_batch(self): - tokenizers = [self.default_tokenizer, self.default_tokenizer_fast] - src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] - tgt_text = [ - "Summary of the text.", - "Another summary.", - ] - expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2] - - for tokenizer in tokenizers: - batch = tokenizer.prepare_seq2seq_batch( - src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt" - ) - self.assertIsInstance(batch, BatchEncoding) - - self.assertEqual((2, 10), batch.input_ids.shape) - self.assertEqual((2, 10), batch.attention_mask.shape) - result = batch.input_ids.tolist()[0] - self.assertListEqual(expected_src_tokens, result) - # Test that special tokens are reset - - def test_empty_target_text(self): - tokenizers = [self.default_tokenizer, self.default_tokenizer_fast] - src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] - for tokenizer in tokenizers: - batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt") - # check if input_ids are returned and no decoder_input_ids - self.assertIn("input_ids", batch) - self.assertIn("attention_mask", batch) - self.assertNotIn("decoder_input_ids", batch) - self.assertNotIn("decoder_attention_mask", batch) - - def test_max_target_length(self): - tokenizers = [self.default_tokenizer, self.default_tokenizer_fast] - src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] - tgt_text = [ - "Summary of the text.", - "Another summary.", - ] - for tokenizer in tokenizers: - batch = tokenizer.prepare_seq2seq_batch( - src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt" - ) - self.assertEqual(32, batch["decoder_input_ids"].shape[1]) - self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) - - # test None max_target_length - batch = tokenizer.prepare_seq2seq_batch( - src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt" - ) - self.assertEqual(32, batch["decoder_input_ids"].shape[1]) - self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) - - def test_outputs_not_longer_than_maxlen(self): - tokenizers = [self.default_tokenizer, self.default_tokenizer_fast] - - for tokenizer in tokenizers: - batch = tokenizer.prepare_seq2seq_batch( - ["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt" - ) - self.assertIsInstance(batch, BatchEncoding) - self.assertEqual(batch.input_ids.shape, (2, 1024)) - - def test_special_tokens(self): - tokenizers = [self.default_tokenizer, self.default_tokenizer_fast] - src_text = ["A long paragraph for summrization."] - tgt_text = [ - "Summary of the text.", - ] - for tokenizer in tokenizers: - batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt") - input_ids = batch["input_ids"] - decoder_input_ids = batch["decoder_input_ids"] - self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item()) - self.assertTrue((decoder_input_ids[:, 0] == tokenizer.bos_token_id).all().item()) - self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item()) - self.assertTrue((decoder_input_ids[:, -1] == tokenizer.eos_token_id).all().item()) - @require_torch class TestSinusoidalPositionalEmbeddings(unittest.TestCase): diff --git a/tests/test_tokenization_bart.py b/tests/test_tokenization_bart.py new file mode 100644 index 0000000000..59fe1786da --- /dev/null +++ b/tests/test_tokenization_bart.py @@ -0,0 +1,145 @@ +import json +import os +import unittest + +from transformers import BartTokenizer, BartTokenizerFast, BatchEncoding +from transformers.file_utils import cached_property +from transformers.testing_utils import require_torch +from transformers.tokenization_roberta import VOCAB_FILES_NAMES + +from .test_tokenization_common import TokenizerTesterMixin + + +class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase): + tokenizer_class = BartTokenizer + + def setUp(self): + super().setUp() + vocab = [ + "l", + "o", + "w", + "e", + "r", + "s", + "t", + "i", + "d", + "n", + "\u0120", + "\u0120l", + "\u0120n", + "\u0120lo", + "\u0120low", + "er", + "\u0120lowest", + "\u0120newer", + "\u0120wider", + "", + ] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] + self.special_tokens_map = {"unk_token": ""} + + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(vocab_tokens) + "\n") + with open(self.merges_file, "w", encoding="utf-8") as fp: + fp.write("\n".join(merges)) + + def get_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) + + def get_rust_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return BartTokenizerFast.from_pretrained(self.tmpdirname, **kwargs) + + def get_input_output_texts(self, tokenizer): + return "lower newer", "lower newer" + + @cached_property + def default_tokenizer(self): + return BartTokenizer.from_pretrained("facebook/bart-large") + + @cached_property + def default_tokenizer_fast(self): + return BartTokenizerFast.from_pretrained("facebook/bart-large") + + @require_torch + def test_prepare_seq2seq_batch(self): + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + tgt_text = [ + "Summary of the text.", + "Another summary.", + ] + expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2] + + for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt" + ) + self.assertIsInstance(batch, BatchEncoding) + + self.assertEqual((2, 10), batch.input_ids.shape) + self.assertEqual((2, 10), batch.attention_mask.shape) + result = batch.input_ids.tolist()[0] + self.assertListEqual(expected_src_tokens, result) + # Test that special tokens are reset + + # Test Prepare Seq + @require_torch + def test_seq2seq_batch_empty_target_text(self): + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: + batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt") + # check if input_ids are returned and no labels + self.assertIn("input_ids", batch) + self.assertIn("attention_mask", batch) + self.assertNotIn("labels", batch) + self.assertNotIn("decoder_attention_mask", batch) + + @require_torch + def test_seq2seq_batch_max_target_length(self): + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + tgt_text = [ + "Summary of the text.", + "Another summary.", + ] + for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt" + ) + self.assertEqual(32, batch["labels"].shape[1]) + + # test None max_target_length + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt" + ) + self.assertEqual(32, batch["labels"].shape[1]) + + @require_torch + def test_seq2seq_batch_not_longer_than_maxlen(self): + for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: + batch = tokenizer.prepare_seq2seq_batch( + ["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt" + ) + self.assertIsInstance(batch, BatchEncoding) + self.assertEqual(batch.input_ids.shape, (2, 1024)) + + @require_torch + def test_special_tokens(self): + + src_text = ["A long paragraph for summrization."] + tgt_text = [ + "Summary of the text.", + ] + for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]: + batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt") + input_ids = batch["input_ids"] + labels = batch["labels"] + self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item()) + self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item()) + self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item()) + self.assertTrue((labels[:, -1] == tokenizer.eos_token_id).all().item()) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 55c14bfacb..22e845d2c4 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1555,14 +1555,19 @@ class TokenizerTesterMixin: "vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.", ] batch = tokenizer.prepare_seq2seq_batch( - src_texts=src_text, tgt_texts=tgt_text, max_length=3, max_target_length=10, return_tensors="pt" + src_texts=src_text, + tgt_texts=tgt_text, + max_length=3, + max_target_length=10, + return_tensors="pt", + src_lang="en_XX", # this should be ignored (for all but mbart) but not cause an error ) self.assertEqual(batch.input_ids.shape[1], 3) - self.assertEqual(batch.decoder_input_ids.shape[1], 10) + self.assertEqual(batch.labels.shape[1], 10) # max_target_length will default to max_length if not specified batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3) self.assertEqual(batch.input_ids.shape[1], 3) - self.assertEqual(batch.decoder_input_ids.shape[1], 3) + self.assertEqual(batch.labels.shape[1], 3) batch_encoder_only = tokenizer.prepare_seq2seq_batch( src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt" diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py index bda0be3aec..e6f77d7514 100644 --- a/tests/test_tokenization_mbart.py +++ b/tests/test_tokenization_mbart.py @@ -1,13 +1,16 @@ import tempfile import unittest -from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer +from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer, is_torch_available from transformers.testing_utils import require_torch from .test_tokenization_common import TokenizerTesterMixin from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE +if is_torch_available(): + from transformers.modeling_bart import shift_tokens_right + EN_CODE = 250004 RO_CODE = 250020 @@ -123,35 +126,6 @@ class MBartEnroIntegrationTest(unittest.TestCase): self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004) self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020) - def test_enro_tokenizer_prepare_seq2seq_batch(self): - batch = self.tokenizer.prepare_seq2seq_batch( - self.src_text, - tgt_texts=self.tgt_text, - max_length=len(self.expected_src_tokens), - ) - self.assertIsInstance(batch, BatchEncoding) - - self.assertEqual((2, 14), batch.input_ids.shape) - self.assertEqual((2, 14), batch.attention_mask.shape) - result = batch.input_ids.tolist()[0] - self.assertListEqual(self.expected_src_tokens, result) - self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS - # Test that special tokens are reset - self.assertEqual(self.tokenizer.prefix_tokens, []) - self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE]) - - def test_max_target_length(self): - - batch = self.tokenizer.prepare_seq2seq_batch( - self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10 - ) - self.assertEqual(batch.input_ids.shape[1], 3) - self.assertEqual(batch.decoder_input_ids.shape[1], 10) - # max_target_length will default to max_length if not specified - batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3) - self.assertEqual(batch.input_ids.shape[1], 3) - self.assertEqual(batch.decoder_input_ids.shape[1], 3) - def test_enro_tokenizer_batch_encode_plus(self): ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0] self.assertListEqual(self.expected_src_tokens, ids) @@ -169,7 +143,9 @@ class MBartEnroIntegrationTest(unittest.TestCase): assert isinstance(src_text[0], str) desired_max_length = 10 ids = self.tokenizer.prepare_seq2seq_batch( - src_text, return_tensors=None, max_length=desired_max_length + src_text, + return_tensors=None, + max_length=desired_max_length, ).input_ids[0] self.assertEqual(ids[-2], 2) self.assertEqual(ids[-1], EN_CODE) @@ -184,3 +160,53 @@ class MBartEnroIntegrationTest(unittest.TestCase): self.tokenizer.save_pretrained(tmpdirname) new_tok = MBartTokenizer.from_pretrained(tmpdirname) self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens) + + # prepare_seq2seq_batch tests below + + @require_torch + def test_batch_fairseq_parity(self): + batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch( + self.src_text, tgt_texts=self.tgt_text, return_tensors="pt" + ) + batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + for k in batch: + batch[k] = batch[k].tolist() + # batch = {k: v.tolist() for k,v in batch.items()} + # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4 + # batch.decoder_inputs_ids[0][0] == + assert batch.input_ids[1][-2:] == [2, EN_CODE] + assert batch.decoder_input_ids[1][0] == RO_CODE + assert batch.decoder_input_ids[1][-1] == 2 + assert batch.labels[1][-2:] == [2, RO_CODE] + + @require_torch + def test_enro_tokenizer_prepare_seq2seq_batch(self): + batch = self.tokenizer.prepare_seq2seq_batch( + self.src_text, + tgt_texts=self.tgt_text, + max_length=len(self.expected_src_tokens), + ) + batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + self.assertIsInstance(batch, BatchEncoding) + + self.assertEqual((2, 14), batch.input_ids.shape) + self.assertEqual((2, 14), batch.attention_mask.shape) + result = batch.input_ids.tolist()[0] + self.assertListEqual(self.expected_src_tokens, result) + self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS + # Test that special tokens are reset + self.assertEqual(self.tokenizer.prefix_tokens, []) + self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE]) + + def test_seq2seq_max_target_length(self): + batch = self.tokenizer.prepare_seq2seq_batch( + self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10 + ) + batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + self.assertEqual(batch.input_ids.shape[1], 3) + self.assertEqual(batch.decoder_input_ids.shape[1], 10) + # max_target_length will default to max_length if not specified + batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3) + batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id) + self.assertEqual(batch.input_ids.shape[1], 3) + self.assertEqual(batch.decoder_input_ids.shape[1], 3) diff --git a/tests/test_tokenization_pegasus.py b/tests/test_tokenization_pegasus.py index 30af7c5efa..88a0a1bed4 100644 --- a/tests/test_tokenization_pegasus.py +++ b/tests/test_tokenization_pegasus.py @@ -63,7 +63,6 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase): batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5) assert batch.input_ids.shape == (2, 1024) assert batch.attention_mask.shape == (2, 1024) - assert "decoder_input_ids" in batch # because tgt_texts was specified - assert batch.decoder_input_ids.shape == (2, 5) - assert batch.decoder_attention_mask.shape == (2, 5) - assert len(batch) == 4 # no extra keys + assert "labels" in batch # because tgt_texts was specified + assert batch.labels.shape == (2, 5) + assert len(batch) == 3 # input_ids, attention_mask, labels. Other things make by BartModel diff --git a/tests/test_tokenization_roberta.py b/tests/test_tokenization_roberta.py index f2a0cc7424..cbe37f21f1 100644 --- a/tests/test_tokenization_roberta.py +++ b/tests/test_tokenization_roberta.py @@ -66,7 +66,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def get_tokenizer(self, **kwargs): kwargs.update(self.special_tokens_map) - return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs) + return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) def get_rust_tokenizer(self, **kwargs): kwargs.update(self.special_tokens_map) @@ -78,7 +78,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): return input_text, output_text def test_full_tokenizer(self): - tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) + tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map) text = "lower newer" bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"] tokens = tokenizer.tokenize(text) # , add_prefix_space=True) @@ -99,7 +99,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): @slow def test_sequence_builders(self): - tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + tokenizer = self.tokenizer_class.from_pretrained("roberta-base") text = tokenizer.encode("sequence builders", add_special_tokens=False) text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) @@ -137,7 +137,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0] self.assertNotEqual(first_char, space_encoding) - # Testing spaces after special tokenss + # Testing spaces after special tokens mask = "" tokenizer.add_special_tokens( {"mask_token": AddedToken(mask, lstrip=True, rstrip=False)} diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py index a974da8baf..16bf536b25 100644 --- a/tests/test_tokenization_t5.py +++ b/tests/test_tokenization_t5.py @@ -153,7 +153,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase): def test_max_target_length(self): tokenizer = self.t5_base_tokenizer - src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + src_text = ["A short paragraph for summrization.", "Another short paragraph for summrization."] tgt_text = [ "Summary of the text.", "Another summary.", @@ -161,14 +161,14 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase): batch = tokenizer.prepare_seq2seq_batch( src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK ) - self.assertEqual(32, batch["decoder_input_ids"].shape[1]) + self.assertEqual(32, batch["labels"].shape[1]) self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) # test None max_target_length batch = tokenizer.prepare_seq2seq_batch( src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK ) - self.assertEqual(32, batch["decoder_input_ids"].shape[1]) + self.assertEqual(32, batch["labels"].shape[1]) self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) def test_outputs_not_longer_than_maxlen(self): @@ -190,7 +190,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase): batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK) src_ids = list(batch.input_ids.numpy()[0]) - tgt_ids = list(batch.decoder_input_ids.numpy()[0]) + tgt_ids = list(batch.labels.numpy()[0]) self.assertEqual(expected_src_tokens, src_ids) self.assertEqual(expected_tgt_tokens, tgt_ids)