diff --git a/docs/source/model_doc/bart.rst b/docs/source/model_doc/bart.rst index 46046ba78f..8088b1def1 100644 --- a/docs/source/model_doc/bart.rst +++ b/docs/source/model_doc/bart.rst @@ -39,6 +39,18 @@ BartTokenizer :members: +MBartTokenizer +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBartTokenizer + :members: build_inputs_with_special_tokens, prepare_translation_batch + +BartForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BartForConditionalGeneration + :members: generate, forward + BartModel ~~~~~~~~~~~~~ @@ -62,10 +74,3 @@ BartForQuestionAnswering :members: forward -BartForConditionalGeneration -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: transformers.BartForConditionalGeneration - :members: generate, forward - - diff --git a/examples/longform-qa/eli5_app.py b/examples/longform-qa/eli5_app.py index a7d75565ae..675ffbf63d 100644 --- a/examples/longform-qa/eli5_app.py +++ b/examples/longform-qa/eli5_app.py @@ -1,10 +1,10 @@ import faiss import nlp import numpy as np +import streamlit as st import torch from elasticsearch import Elasticsearch -import streamlit as st import transformers from eli5_utils import ( embed_questions_for_retrieval, diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index 762dd5b4b9..fbeab97cbd 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -41,6 +41,28 @@ If you are using your own data, it must be formatted as one directory with 6 fil The `.source` files are the input, the `.target` files are the desired output. +### Tips and Tricks + +General Tips: +- since you need to run from `examples/seq2seq`, and likely need to modify code, the easiest workflow is fork transformers, clone your fork, and run `pip install -e .` before you get started. +- try `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr per epoch with bs=8, see the "xsum_shared_task" command below) +- `fp16_opt_level=O1` (the default works best). +- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved. +Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`. +- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code. +- This warning can be safely ignored: + > "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']" +- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start). +- Read scripts before you run them! + +Summarization Tips: +- (summ) 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100. +- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter. +- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()` +- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM. +- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task. +- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries. +(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods). ### Summarization Finetuning Run/modify `finetune.sh` @@ -58,25 +80,20 @@ The following command should work on a 16GB GPU: *Note*: The following tips mostly apply to summarization finetuning. -Tips: -- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100. -- since you need to run from `examples/seq2seq`, and likely need to modify code, it is easiest to fork, then clone transformers and run `pip install -e .` before you get started. -- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below) -- `fp16_opt_level=O1` (the default works best). -- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries. -(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods). -- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved. -Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`. -- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code. -- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter. -- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()` -- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM. -- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task. -- This warning can be safely ignored: - > "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']" -- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start). +### Translation Finetuning -#### Finetuning Outputs +First, follow the wmt_en_ro download instructions. +Then you can finetune mbart_cc25 on english-romanian with the following command. +**Recommendation:** Read and potentially modify the fairly opinionated defaults in `train_mbart_cc25_enro.sh` script before running it. +```bash +export ENRO_DIR=${PWD}/wmt_en_ro # may need to be fixed depending on where you downloaded +export BS=4 +export GAS=8 +./train_mbart_cc25_enro.sh --output_dir cc25_v1_frozen/ +``` + + +### Finetuning Outputs As you train, `output_dir` will be filled with files, that look kind of like this (comments are mine). Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour: diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index fd1d00ecf3..16b7335b2f 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -14,11 +14,12 @@ import torch from torch.utils.data import DataLoader from lightning_base import BaseTransformer, add_generic_args, generic_train -from transformers import get_linear_schedule_with_warmup +from transformers import MBartTokenizer, get_linear_schedule_with_warmup try: from .utils import ( + assert_all_frozen, use_task_specific_params, SummarizationDataset, lmap, @@ -47,6 +48,7 @@ except ImportError: get_git_info, ROUGE_KEYS, calculate_bleu_score, + assert_all_frozen, ) from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback @@ -92,9 +94,12 @@ class SummarizationModule(BaseTransformer): if self.hparams.freeze_embeds: self.freeze_embeds() if self.hparams.freeze_encoder: - freeze_params(self.model.model.encoder) # TODO: this will break for t5 + freeze_params(self.model.get_encoder()) + assert_all_frozen(self.model.get_encoder()) + self.hparams.git_sha = get_git_info()["repo_sha"] self.num_workers = hparams.num_workers + self.decoder_start_token_id = None def freeze_embeds(self): """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" @@ -160,7 +165,12 @@ class SummarizationModule(BaseTransformer): pad_token_id = self.tokenizer.pad_token_id source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id) t0 = time.time() - generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,) + generated_ids = self.model.generate( + input_ids=source_ids, + attention_mask=source_mask, + use_cache=True, + decoder_start_token_id=self.decoder_start_token_id, + ) gen_time = (time.time() - t0) / source_ids.shape[0] preds = self.ids_to_clean_text(generated_ids) target = self.ids_to_clean_text(y) @@ -276,6 +286,9 @@ class SummarizationModule(BaseTransformer): parser.add_argument( "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all." ) + parser.add_argument("--src_lang", type=str, default="", required=False) + parser.add_argument("--tgt_lang", type=str, default="", required=False) + return parser @@ -285,6 +298,13 @@ class TranslationModule(SummarizationModule): metric_names = ["bleu"] val_metric = "bleu" + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.dataset_kwargs["src_lang"] = hparams.src_lang + self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang + if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer): + self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang] + def calc_generative_metrics(self, preds, target) -> dict: return calculate_bleu_score(preds, target) diff --git a/examples/seq2seq/finetune_t5.sh b/examples/seq2seq/finetune_t5.sh index 1a97b08117..ed8d26634c 100755 --- a/examples/seq2seq/finetune_t5.sh +++ b/examples/seq2seq/finetune_t5.sh @@ -1,18 +1,13 @@ -export OUTPUT_DIR_NAME=t5 -export CURRENT_DIR=${PWD} -export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME} - -# Make output directory if it doesn't exist -mkdir -p $OUTPUT_DIR - # Add parent directory to python path to access lightning_base.py export PYTHONPATH="../":"${PYTHONPATH}" python finetune.py \ ---data_dir=./cnn-dailymail/cnn_dm \ ---model_name_or_path=t5-large \ +--data_dir=$CNN_DIR \ --learning_rate=3e-5 \ ---train_batch_size=4 \ ---eval_batch_size=4 \ +--train_batch_size=$BS \ +--eval_batch_size=$BS \ --output_dir=$OUTPUT_DIR \ ---do_train $@ +--max_source_length=512 \ +--val_check_interval=0.1 --n_val=200 \ +--do_train --do_predict \ + $@ diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 92e75993a9..3989046e65 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -223,10 +223,30 @@ def test_finetune(model): output_dir=output_dir, do_predict=True, task=task, + src_lang="en_XX", + tgt_lang="ro_RO", + freeze_encoder=True, + freeze_embeds=True, ) assert "n_train" in args_d args = argparse.Namespace(**args_d) - main(args) + module = main(args) + + input_embeds = module.model.get_input_embeddings() + assert not input_embeds.weight.requires_grad + if model == T5_TINY: + lm_head = module.model.lm_head + assert not lm_head.weight.requires_grad + assert (lm_head.weight == input_embeds.weight).all().item() + + else: + bart = module.model.model + embed_pos = bart.decoder.embed_positions + assert not embed_pos.weight.requires_grad + assert not bart.shared.weight.requires_grad + # check that embeds are the same + assert bart.decoder.embed_tokens == bart.encoder.embed_tokens + assert bart.decoder.embed_tokens == bart.shared @pytest.mark.parametrize( @@ -239,7 +259,12 @@ def test_dataset(tok): max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES) trunc_target = 4 train_dataset = SummarizationDataset( - tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target, + tokenizer, + data_dir=tmp_dir, + type_path="train", + max_source_length=20, + max_target_length=trunc_target, + tgt_lang="ro_RO", ) dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn) for batch in dataloader: diff --git a/examples/seq2seq/train_mbart_cc25_enro.sh b/examples/seq2seq/train_mbart_cc25_enro.sh new file mode 100755 index 0000000000..2fd5268cd4 --- /dev/null +++ b/examples/seq2seq/train_mbart_cc25_enro.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +export PYTHONPATH="../":"${PYTHONPATH}" + +python finetune.py \ + --learning_rate=3e-5 \ + --fp16 \ + --gpus 1 \ + --do_train \ + --do_predict \ + --val_check_interval 0.1 \ + --n_val 500 \ + --adam_eps 1e-06 \ + --num_train_epochs 3 --src_lang en_XX --tgt_lang ro_RO \ + --freeze_encoder --freeze_embeds --data_dir $ENRO_DIR \ + --max_source_length=300 --max_target_length 300 --val_max_target_length=300 --test_max_target_length 300 \ + --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \ + --model_name_or_path facebook/mbart-large-cc25 \ + --task translation \ + --warmup_steps 500 \ + --logger wandb --sortish_sampler \ + $@ diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 7e2b862dd8..1c3a9c1696 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -14,6 +14,8 @@ from torch import nn from torch.utils.data import Dataset, Sampler from tqdm import tqdm +from transformers import BartTokenizer + def encode_file( tokenizer, @@ -25,6 +27,7 @@ def encode_file( prefix="", tok_name="", ): + extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt") if not overwrite_cache and cache_path.exists(): try: @@ -46,8 +49,8 @@ def encode_file( max_length=max_length, padding="max_length" if pad_to_max_length else None, truncation=True, - add_prefix_space=True, return_tensors=return_tensors, + **extra_kw, ) assert tokenized.input_ids.shape[1] == max_length examples.append(tokenized) @@ -87,9 +90,14 @@ class SummarizationDataset(Dataset): n_obs=None, overwrite_cache=False, prefix="", + src_lang=None, + tgt_lang=None, ): super().__init__() + # FIXME: the rstrip logic strips all the chars, it seems. tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer") + if hasattr(tokenizer, "set_lang") and src_lang is not None: + tokenizer.set_lang(src_lang) # HACK: only applies to mbart self.source = encode_file( tokenizer, os.path.join(data_dir, type_path + ".source"), @@ -100,7 +108,8 @@ class SummarizationDataset(Dataset): ) tgt_path = os.path.join(data_dir, type_path + ".target") if hasattr(tokenizer, "set_lang"): - tokenizer.set_lang("ro_RO") # HACK: only applies to mbart + assert tgt_lang is not None, "--tgt_lang must be passed to build a translation" + tokenizer.set_lang(tgt_lang) # HACK: only applies to mbart self.target = encode_file( tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name ) @@ -224,8 +233,8 @@ def get_git_info(): ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] -def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict: - scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True) +def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict: + scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) aggregator = scoring.BootstrapAggregator() for reference_ln, output_ln in zip(reference_lns, output_lns): diff --git a/setup.cfg b/setup.cfg index aa1dfcf111..d1e67228d2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,7 @@ known_third_party = sacrebleu seqeval sklearn + streamlit tensorboardX tensorflow tensorflow_datasets diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py index 0640593d47..7378da9680 100644 --- a/src/transformers/tokenization_bart.py +++ b/src/transformers/tokenization_bart.py @@ -55,15 +55,16 @@ class BartTokenizerFast(RobertaTokenizerFast): } -_all_mbart_models = ["facebook/mbart-large-en-ro", "sshleifer/mbart-large-cc25"] +_all_mbart_models = ["facebook/mbart-large-en-ro", "facebook/mbart-large-cc25"] SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model" class MBartTokenizer(XLMRobertaTokenizer): """ This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs. - Other tokenizer methods like encode do not work properly. - The tokenization method is . There is no BOS token. + Other tokenizer methods like ``encode`` do not work properly. + The tokenization method is `` `` for source language documents, and + `` ``` for target language documents. Examples:: @@ -109,24 +110,84 @@ class MBartTokenizer(XLMRobertaTokenizer): } id_to_lang_code = {v: k for k, v in lang_code_to_id.items()} cur_lang_code = lang_code_to_id["en_XX"] + prefix_tokens: List[int] = [] + suffix_tokens: List[int] = [] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.fairseq_tokens_to_ids.update(self.lang_code_to_id) self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} self._additional_special_tokens = list(self.lang_code_to_id.keys()) + self.reset_special_tokens() - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: - """Build model inputs from a sequence by appending eos_token_id.""" - special_tokens = [self.eos_token_id, self.cur_lang_code] + def reset_special_tokens(self) -> None: + """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code].""" + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. The special tokens depend on calling set_lang. + An MBART sequence has the following format, where ``X`` represents the sequence: + - ``input_ids`` (for encoder) ``X [eos, src_lang_code]`` + - ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]`` + BOS is never used. + Pairs of sequences are not the expected use case, but they will be handled without a separator. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ if token_ids_1 is None: - return token_ids_0 + special_tokens + return self.prefix_tokens + token_ids_0 + self.suffix_tokens # We don't expect to process pairs, but leave the pair logic for API consistency - return token_ids_0 + token_ids_1 + special_tokens + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` methods. + + Args: + token_ids_0 (:obj:`List[int]`): + List of ids. + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Set to True if the token list is already formatted with special tokens for the model + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] * len(self.suffix_tokens) + if token_ids_1 is None: + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones def set_lang(self, lang: str) -> None: """Set the current language code in order to call tokenizer properly.""" self.cur_lang_code = self.lang_code_to_id[lang] + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id] def prepare_translation_batch( self, @@ -135,44 +196,45 @@ class MBartTokenizer(XLMRobertaTokenizer): tgt_texts: Optional[List[str]] = None, tgt_lang: str = "ro_RO", max_length: Optional[int] = None, - pad_to_max_length: bool = True, + padding: str = "longest", return_tensors: str = "pt", ) -> BatchEncoding: - """ + """Prepare a batch that can be passed directly to an instance of MBartModel. Arguments: src_texts: list of src language texts - src_lang: default en_XX (english) + src_lang: default en_XX (english), the language we are translating from tgt_texts: list of tgt language texts - tgt_lang: default ro_RO (romanian) - max_length: (None) defer to config (1024 for mbart-large-en-ro) - pad_to_max_length: (bool) + tgt_lang: default ro_RO (romanian), the language we are translating to + max_length: (default=None, which defers to the config value of 1024 for facebook/mbart-large* + padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest. Returns: - dict with keys input_ids, attention_mask, decoder_input_ids, each value is a torch.Tensor. + :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask. """ if max_length is None: max_length = self.max_len self.cur_lang_code = self.lang_code_to_id[src_lang] - model_inputs: BatchEncoding = self.batch_encode_plus( + model_inputs: BatchEncoding = self( src_texts, add_special_tokens=True, return_tensors=return_tensors, max_length=max_length, - pad_to_max_length=pad_to_max_length, + padding=padding, truncation=True, ) if tgt_texts is None: return model_inputs - self.cur_lang_code = self.lang_code_to_id[tgt_lang] - decoder_inputs: BatchEncoding = self.batch_encode_plus( + self.set_lang(tgt_lang) + decoder_inputs: BatchEncoding = self( tgt_texts, add_special_tokens=True, return_tensors=return_tensors, + padding=padding, max_length=max_length, - pad_to_max_length=pad_to_max_length, truncation=True, ) for k, v in decoder_inputs.items(): model_inputs[f"decoder_{k}"] = v self.cur_lang_code = self.lang_code_to_id[src_lang] + self.reset_special_tokens() # sets to src_lang return model_inputs diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 26c25806f5..5ea77c70eb 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -19,7 +19,6 @@ import unittest import timeout_decorator # noqa from transformers import is_torch_available -from transformers.file_utils import cached_property from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester @@ -31,7 +30,6 @@ if is_torch_available(): from transformers import ( AutoModel, AutoModelForSequenceClassification, - AutoModelForSeq2SeqLM, AutoTokenizer, BartModel, BartForConditionalGeneration, @@ -39,7 +37,6 @@ if is_torch_available(): BartForQuestionAnswering, BartConfig, BartTokenizer, - BatchEncoding, pipeline, ) from transformers.modeling_bart import ( @@ -202,140 +199,6 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase): tiny(**inputs_dict) -EN_CODE = 250004 - - -@require_torch -class MBartIntegrationTests(unittest.TestCase): - src_text = [ - " UN Chief Says There Is No Military Solution in Syria", - """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""", - ] - tgt_text = [ - "Şeful ONU declară că nu există o soluţie militară în Siria", - 'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.', - ] - - expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE] - - @classmethod - def setUpClass(cls): - checkpoint_name = "facebook/mbart-large-en-ro" - cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint_name) - cls.pad_token_id = 1 - return cls - - @cached_property - def model(self): - """Only load the model if needed.""" - model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-en-ro").to(torch_device) - if "cuda" in torch_device: - model = model.half() - return model - - @slow - @unittest.skip("This has been failing since June 20th at least.") - def test_enro_forward(self): - model = self.model - net_input = { - "input_ids": _long_tensor( - [ - [3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004], - [64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004], - ] - ), - "decoder_input_ids": _long_tensor( - [ - [250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1], - [250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2], - ] - ), - } - net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id) - with torch.no_grad(): - logits, *other_stuff = model(**net_input) - - expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=logits.device, dtype=logits.dtype) - result_slice = logits[0, 0, :3] - _assert_tensors_equal(expected_slice, result_slice, atol=TOLERANCE) - - @slow - def test_enro_generate(self): - batch: BatchEncoding = self.tokenizer.prepare_translation_batch(self.src_text).to(torch_device) - translated_tokens = self.model.generate(**batch) - decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) - self.assertEqual(self.tgt_text[0], decoded[0]) - self.assertEqual(self.tgt_text[1], decoded[1]) - - def test_mbart_enro_config(self): - mbart_models = ["facebook/mbart-large-en-ro"] - expected = {"scale_embedding": True, "output_past": True} - for name in mbart_models: - config = BartConfig.from_pretrained(name) - self.assertTrue(config.is_valid_mbart()) - for k, v in expected.items(): - try: - self.assertEqual(v, getattr(config, k)) - except AssertionError as e: - e.args += (name, k) - raise - - def test_mbart_fast_forward(self): - config = BartConfig( - vocab_size=99, - d_model=24, - encoder_layers=2, - decoder_layers=2, - encoder_attention_heads=2, - decoder_attention_heads=2, - encoder_ffn_dim=32, - decoder_ffn_dim=32, - max_position_embeddings=48, - add_final_layer_norm=True, - ) - lm_model = BartForConditionalGeneration(config).to(torch_device) - context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device) - summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device) - loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary) - expected_shape = (*summary.shape, config.vocab_size) - self.assertEqual(logits.shape, expected_shape) - - def test_enro_tokenizer_prepare_translation_batch(self): - batch = self.tokenizer.prepare_translation_batch( - self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), - ) - self.assertIsInstance(batch, BatchEncoding) - - self.assertEqual((2, 14), batch.input_ids.shape) - self.assertEqual((2, 14), batch.attention_mask.shape) - result = batch.input_ids.tolist()[0] - self.assertListEqual(self.expected_src_tokens, result) - self.assertEqual(2, batch.decoder_input_ids[0, -2]) # EOS - - def test_enro_tokenizer_batch_encode_plus(self): - ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0] - self.assertListEqual(self.expected_src_tokens, ids) - - def test_enro_tokenizer_decode_ignores_language_codes(self): - self.assertIn(250020, self.tokenizer.all_special_ids) - generated_ids = [250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2] - result = self.tokenizer.decode(generated_ids, skip_special_tokens=True) - expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True) - self.assertEqual(result, expected_romanian) - self.assertNotIn(self.tokenizer.eos_token, result) - - def test_enro_tokenizer_truncation(self): - src_text = ["this is gunna be a long sentence " * 20] - assert isinstance(src_text[0], str) - desired_max_length = 10 - ids = self.tokenizer.prepare_translation_batch( - src_text, return_tensors=None, max_length=desired_max_length - ).input_ids[0] - self.assertEqual(ids[-2], 2) - self.assertEqual(ids[-1], EN_CODE) - self.assertEqual(len(ids), desired_max_length) - - @require_torch class BartHeadTests(unittest.TestCase): vocab_size = 99 diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py new file mode 100644 index 0000000000..2a50fa0688 --- /dev/null +++ b/tests/test_modeling_mbart.py @@ -0,0 +1,142 @@ +import unittest + +from transformers import is_torch_available +from transformers.file_utils import cached_property +from transformers.testing_utils import require_torch, slow, torch_device + +from .test_modeling_bart import TOLERANCE, _assert_tensors_equal, _long_tensor + + +if is_torch_available(): + import torch + from transformers import ( + AutoModelForSeq2SeqLM, + BartConfig, + BartForConditionalGeneration, + BatchEncoding, + AutoTokenizer, + ) + + +EN_CODE = 250004 +RO_CODE = 250020 + + +@require_torch +class AbstractMBartIntegrationTest(unittest.TestCase): + + checkpoint_name = None + + @classmethod + def setUpClass(cls): + cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name) + cls.pad_token_id = 1 + return cls + + @cached_property + def model(self): + """Only load the model if needed.""" + model = AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name).to(torch_device) + if "cuda" in torch_device: + model = model.half() + return model + + +@require_torch +class MBartEnroIntegrationTest(AbstractMBartIntegrationTest): + checkpoint_name = "facebook/mbart-large-en-ro" + src_text = [ + " UN Chief Says There Is No Military Solution in Syria", + """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""", + ] + tgt_text = [ + "Şeful ONU declară că nu există o soluţie militară în Siria", + 'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.', + ] + expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE] + + @slow + @unittest.skip("This has been failing since June 20th at least.") + def test_enro_forward(self): + model = self.model + net_input = { + "input_ids": _long_tensor( + [ + [3493, 3060, 621, 104064, 1810, 100, 142, 566, 13158, 6889, 5, 2, 250004], + [64511, 7, 765, 2837, 45188, 297, 4049, 237, 10, 122122, 5, 2, 250004], + ] + ), + "decoder_input_ids": _long_tensor( + [ + [250020, 31952, 144, 9019, 242307, 21980, 55749, 11, 5, 2, 1, 1], + [250020, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2], + ] + ), + } + net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id) + with torch.no_grad(): + logits, *other_stuff = model(**net_input) + + expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=logits.device, dtype=logits.dtype) + result_slice = logits[0, 0, :3] + _assert_tensors_equal(expected_slice, result_slice, atol=TOLERANCE) + + @slow + def test_enro_generate(self): + batch: BatchEncoding = self.tokenizer.prepare_translation_batch(self.src_text).to(torch_device) + translated_tokens = self.model.generate(**batch) + decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) + self.assertEqual(self.tgt_text[0], decoded[0]) + self.assertEqual(self.tgt_text[1], decoded[1]) + + def test_mbart_enro_config(self): + mbart_models = ["facebook/mbart-large-en-ro"] + expected = {"scale_embedding": True, "output_past": True} + for name in mbart_models: + config = BartConfig.from_pretrained(name) + self.assertTrue(config.is_valid_mbart()) + for k, v in expected.items(): + try: + self.assertEqual(v, getattr(config, k)) + except AssertionError as e: + e.args += (name, k) + raise + + def test_mbart_fast_forward(self): + config = BartConfig( + vocab_size=99, + d_model=24, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=2, + decoder_attention_heads=2, + encoder_ffn_dim=32, + decoder_ffn_dim=32, + max_position_embeddings=48, + add_final_layer_norm=True, + ) + lm_model = BartForConditionalGeneration(config).to(torch_device) + context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device) + summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device) + loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary) + expected_shape = (*summary.shape, config.vocab_size) + self.assertEqual(logits.shape, expected_shape) + + +class MBartCC25IntegrationTest(AbstractMBartIntegrationTest): + checkpoint_name = "facebook/mbart-large-cc25" + src_text = [ + " UN Chief Says There Is No Military Solution in Syria", + " I ate lunch twice yesterday", + ] + tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"] + + @unittest.skip("This test is broken, still generates english") + def test_cc25_generate(self): + inputs = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device) + translated_tokens = self.model.generate( + input_ids=inputs["input_ids"].to(torch_device), + decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"], + ) + decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) + self.assertEqual(self.tgt_text[0], decoded[0]) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index b9c866cdc5..f0456290b9 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -903,6 +903,7 @@ class TokenizerTesterMixin: tokenizer.padding_side = "right" encoded_sequence = tokenizer.encode(sequence) sequence_length = len(encoded_sequence) + # FIXME: the next line should be padding(max_length) to avoid warning padded_sequence = tokenizer.encode( sequence, max_length=sequence_length + padding_size, pad_to_max_length=True ) diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py new file mode 100644 index 0000000000..84bc3ba8b2 --- /dev/null +++ b/tests/test_tokenization_mbart.py @@ -0,0 +1,156 @@ +import unittest + +from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer +from transformers.testing_utils import require_torch + +from .test_tokenization_common import TokenizerTesterMixin +from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE + + +EN_CODE = 250004 +RO_CODE = 250020 + + +class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + tokenizer_class = MBartTokenizer + + def setUp(self): + super().setUp() + + # We have a SentencePiece fixture for testing + tokenizer = MBartTokenizer(SAMPLE_VOCAB, keep_accents=True) + tokenizer.save_pretrained(self.tmpdirname) + + def test_full_tokenizer(self): + tokenizer = MBartTokenizer(SAMPLE_VOCAB, keep_accents=True) + + tokens = tokenizer.tokenize("This is a test") + self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"]) + + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), + [value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]], + ) + + tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") + self.assertListEqual( + tokens, + [ + SPIECE_UNDERLINE + "I", + SPIECE_UNDERLINE + "was", + SPIECE_UNDERLINE + "b", + "or", + "n", + SPIECE_UNDERLINE + "in", + SPIECE_UNDERLINE + "", + "9", + "2", + "0", + "0", + "0", + ",", + SPIECE_UNDERLINE + "and", + SPIECE_UNDERLINE + "this", + SPIECE_UNDERLINE + "is", + SPIECE_UNDERLINE + "f", + "al", + "s", + "é", + ".", + ], + ) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual( + ids, + [ + value + tokenizer.fairseq_offset + for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4] + # ^ unk: 2 + 1 = 3 unk: 2 + 1 = 3 ^ + ], + ) + + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual( + back_tokens, + [ + SPIECE_UNDERLINE + "I", + SPIECE_UNDERLINE + "was", + SPIECE_UNDERLINE + "b", + "or", + "n", + SPIECE_UNDERLINE + "in", + SPIECE_UNDERLINE + "", + "", + "2", + "0", + "0", + "0", + ",", + SPIECE_UNDERLINE + "and", + SPIECE_UNDERLINE + "this", + SPIECE_UNDERLINE + "is", + SPIECE_UNDERLINE + "f", + "al", + "s", + "", + ".", + ], + ) + + +@require_torch +class MBartEnroIntegrationTest(unittest.TestCase): + checkpoint_name = "facebook/mbart-large-en-ro" + src_text = [ + " UN Chief Says There Is No Military Solution in Syria", + """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""", + ] + tgt_text = [ + "Şeful ONU declară că nu există o soluţie militară în Siria", + 'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.', + ] + expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE] + + @classmethod + def setUpClass(cls): + cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name) + cls.pad_token_id = 1 + return cls + + def test_enro_tokenizer_prepare_translation_batch(self): + batch = self.tokenizer.prepare_translation_batch( + self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), + ) + self.assertIsInstance(batch, BatchEncoding) + + self.assertEqual((2, 14), batch.input_ids.shape) + self.assertEqual((2, 14), batch.attention_mask.shape) + result = batch.input_ids.tolist()[0] + self.assertListEqual(self.expected_src_tokens, result) + self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS + # Test that special tokens are reset + self.assertEqual(self.tokenizer.prefix_tokens, []) + self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE]) + + def test_enro_tokenizer_batch_encode_plus(self): + ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0] + self.assertListEqual(self.expected_src_tokens, ids) + + def test_enro_tokenizer_decode_ignores_language_codes(self): + self.assertIn(RO_CODE, self.tokenizer.all_special_ids) + generated_ids = [RO_CODE, 884, 9019, 96, 9, 916, 86792, 36, 18743, 15596, 5, 2] + result = self.tokenizer.decode(generated_ids, skip_special_tokens=True) + expected_romanian = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True) + self.assertEqual(result, expected_romanian) + self.assertNotIn(self.tokenizer.eos_token, result) + + def test_enro_tokenizer_truncation(self): + src_text = ["this is gunna be a long sentence " * 20] + assert isinstance(src_text[0], str) + desired_max_length = 10 + ids = self.tokenizer.prepare_translation_batch( + src_text, return_tensors=None, max_length=desired_max_length + ).input_ids[0] + self.assertEqual(ids[-2], 2) + self.assertEqual(ids[-1], EN_CODE) + self.assertEqual(len(ids), desired_max_length)