From 09a2f40684f77e62d0fd8485fe9d2d610390453f Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sat, 18 Jul 2020 13:57:33 -0400 Subject: [PATCH] Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792) Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com> --- examples/seq2seq/README.md | 46 ++---- examples/seq2seq/distillation.py | 26 +--- examples/seq2seq/finetune.py | 18 ++- examples/seq2seq/test_seq2seq_examples.py | 65 ++++++--- examples/seq2seq/utils.py | 164 ++++++++++++---------- src/transformers/tokenization_bart.py | 33 ++--- 6 files changed, 182 insertions(+), 170 deletions(-) diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index e13ac7980e..e5a6f9da79 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -7,6 +7,15 @@ For `bertabs` instructions, see `bertabs/README.md`. ### Data +XSUM Data: +```bash +cd examples/seq2seq +wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz +tar -xzvf xsum.tar.gz +export XSUM_DIR=${PWD}/xsum +``` +this should make a directory called cnn_dm/ with files like `test.source`. +To use your own data, copy that files format. Each article to be summarized is on its own line. CNN/DailyMail data ```bash @@ -17,18 +26,6 @@ tar -xzvf cnn_dm.tgz export CNN_DIR=${PWD}/cnn_dm ``` -this should make a directory called cnn_dm/ with files like `test.source`. -To use your own data, copy that files format. Each article to be summarized is on its own line. - -XSUM Data: -```bash -cd examples/seq2seq -wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz -tar -xzvf xsum.tar.gz -export XSUM_DIR=${PWD}/xsum -``` - - WMT16 English-Romanian Translation Data: ```bash cd examples/seq2seq @@ -40,7 +37,7 @@ export ENRO_DIR=${PWD}/wmt_en_ro If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target. The `.source` files are the input, the `.target` files are the desired output. - + ### Tips and Tricks General Tips: @@ -64,6 +61,10 @@ Summarization Tips: - If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries. (It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods). +**Update 2018-07-18** +Datasets: Seq2SeqDataset will be used for all models besides MBart, for which MBartDataset will be used.** +A new dataset is needed to support multilingual tasks. + ### Summarization Finetuning Run/modify `finetune.sh` @@ -78,8 +79,6 @@ The following command should work on a 16GB GPU: --model_name_or_path facebook/bart-large ``` - - ### Translation Finetuning First, follow the wmt_en_ro download instructions. @@ -124,23 +123,6 @@ from transformers import AutoModelForSeq2SeqLM model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr') ``` -#### XSUM Shared Task -Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration. - -Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier! -```bash -WANDB_PROJECT='hf_xsum' ./finetune.sh \ - --data_dir $XSUM_DIR \ - --output_dir xsum_frozen_embs \ - --model_name_or_path facebook/bart-large \ - --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \ - --num_train_epochs 6 \ - --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \ - --logger_name wandb -``` - -You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-) - ### Evaluation Commands To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models. diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index a683fd7e05..5cf4d6032b 100644 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -15,28 +15,15 @@ from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Conf try: from .finetune import SummarizationModule - from .initialization_utils import init_student, copy_layers - from .utils import ( - use_task_specific_params, - SummarizationDataset, - pickle_load, - freeze_params, - assert_all_frozen, - any_requires_grad, - ) from .finetune import main as ft_main + from .initialization_utils import init_student, copy_layers + from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad + except ImportError: from finetune import SummarizationModule from finetune import main as ft_main from initialization_utils import init_student, copy_layers - from utils import ( - use_task_specific_params, - SummarizationDataset, - pickle_load, - freeze_params, - assert_all_frozen, - any_requires_grad, - ) + from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad class BartSummarizationDistiller(SummarizationModule): @@ -115,11 +102,6 @@ class BartSummarizationDistiller(SummarizationModule): if self.different_encoder: copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy) - def get_dataset(self, type_path) -> SummarizationDataset: - n_obs = self.n_obs[type_path] - dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs) - return dataset - def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor: if mask is not None: # mask has False at padding_idx diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index cd33892680..1afeafd042 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -21,7 +21,6 @@ try: from .utils import ( assert_all_frozen, use_task_specific_params, - SummarizationDataset, lmap, flatten_list, pickle_save, @@ -32,12 +31,17 @@ try: get_git_info, ROUGE_KEYS, calculate_bleu_score, + Seq2SeqDataset, + MBartDataset, ) + from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback except ImportError: from utils import ( + Seq2SeqDataset, + MBartDataset, + assert_all_frozen, use_task_specific_params, - SummarizationDataset, lmap, flatten_list, pickle_save, @@ -48,7 +52,6 @@ except ImportError: get_git_info, ROUGE_KEYS, calculate_bleu_score, - assert_all_frozen, ) from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback @@ -100,6 +103,7 @@ class SummarizationModule(BaseTransformer): self.hparams.git_sha = get_git_info()["repo_sha"] self.num_workers = hparams.num_workers self.decoder_start_token_id = None + self.dataset_class = Seq2SeqDataset def freeze_embeds(self): """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" @@ -163,7 +167,7 @@ class SummarizationModule(BaseTransformer): def _generative_step(self, batch: dict) -> dict: pad_token_id = self.tokenizer.pad_token_id - source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id) + source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id) t0 = time.time() generated_ids = self.model.generate( input_ids=source_ids, @@ -187,10 +191,10 @@ class SummarizationModule(BaseTransformer): def test_epoch_end(self, outputs): return self.validation_epoch_end(outputs, prefix="test") - def get_dataset(self, type_path) -> SummarizationDataset: + def get_dataset(self, type_path) -> Seq2SeqDataset: n_obs = self.n_obs[type_path] max_target_length = self.target_lens[type_path] - dataset = SummarizationDataset( + dataset = self.dataset_class( self.tokenizer, type_path=type_path, n_obs=n_obs, @@ -303,6 +307,8 @@ class TranslationModule(SummarizationModule): self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer): self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang] + if isinstance(self.tokenizer, MBartTokenizer): + self.dataset_class = MBartDataset def calc_generative_metrics(self, preds, target) -> dict: return calculate_bleu_score(preds, target) diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 61f487d654..abf9b908a6 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -9,16 +9,17 @@ from unittest.mock import patch import pytest import torch +from pytest import param from torch.utils.data import DataLoader -from transformers import AutoTokenizer +from transformers import AutoTokenizer, MBartTokenizer from transformers.testing_utils import require_multigpu from .distillation import distill_main, evaluate_checkpoint from .finetune import main from .pack_dataset import pack_data_dir from .run_eval import generate_summaries_or_translations, run_generate -from .utils import SummarizationDataset, lmap, load_json +from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json logging.basicConfig(level=logging.DEBUG) @@ -26,6 +27,7 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger() CUDA_AVAILABLE = torch.cuda.is_available() CHEAP_ARGS = { + "label_smoothing_eps": 0.2, "logger_name": "default", "length_penalty": 0.5, "cache_dir": "", @@ -80,11 +82,11 @@ CHEAP_ARGS = { def _dump_articles(path: Path, articles: list): - with path.open("w") as f: - f.write("\n".join(articles)) + content = "\n".join(articles) + Path(path).open("w").writelines(content) -ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"] +ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."] SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"] T5_TINY = "patrickvonplaten/t5-tiny-random" BART_TINY = "sshleifer/bart-tiny-random" @@ -208,7 +210,7 @@ def test_run_eval_bart(model): @pytest.mark.parametrize( - ["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)] + ["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)], ) def test_finetune(model): args_d: dict = CHEAP_ARGS.copy() @@ -260,22 +262,50 @@ def test_pack_dataset(): assert orig_paths == new_paths -@pytest.mark.parametrize( - ["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)] -) -def test_dataset(tok): +def test_mbart_dataset_truncation(): + tokenizer = MBartTokenizer.from_pretrained(MBART_TINY) + tmp_dir = make_test_data_dir() + max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES) + max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES) + trunc = 4 + src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON + train_dataset = MBartDataset( + tokenizer, + data_dir=tmp_dir, + type_path="train", + max_source_length=trunc, + max_target_length=1000, # ignored + src_lang=src_lang, + tgt_lang=tgt_lang, + ) + dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn) + for batch in dataloader: + assert isinstance(batch, dict) + assert batch["attention_mask"].shape == batch["input_ids"].shape + # show that articles were trimmed. + assert batch["input_ids"].shape[1] == trunc + # show that targets are the same len + assert batch["decoder_input_ids"].shape[1] == trunc + # check language codes in correct place + assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang] + assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id + assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id + assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang] + + assert max_len_target > trunc # Truncated + assert max_len_source > trunc + break # No need to test every batch + + +@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)]) +def test_summarization_dataset_truncation(tok): tokenizer = AutoTokenizer.from_pretrained(tok) tmp_dir = make_test_data_dir() max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES) max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES) trunc_target = 4 - train_dataset = SummarizationDataset( - tokenizer, - data_dir=tmp_dir, - type_path="train", - max_source_length=20, - max_target_length=trunc_target, - tgt_lang="ro_RO", + train_dataset = Seq2SeqDataset( + tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target, ) dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn) for batch in dataloader: @@ -286,3 +316,4 @@ def test_dataset(tok): # show that targets were truncated assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated assert max_len_target > trunc_target # Truncated + break # No need to test every batch diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 2c1f9aebf5..38f04adee3 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -1,7 +1,9 @@ import itertools import json +import linecache import os import pickle +import warnings from logging import getLogger from pathlib import Path from typing import Callable, Dict, Iterable, List @@ -13,50 +15,20 @@ from rouge_score import rouge_scorer, scoring from sacrebleu import corpus_bleu from torch import nn from torch.utils.data import Dataset, Sampler -from tqdm import tqdm from transformers import BartTokenizer -def encode_file( - tokenizer, - data_path, - max_length, - pad_to_max_length=True, - return_tensors="pt", - overwrite_cache=False, - prefix="", - tok_name="", -): +def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} - cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt") - if not overwrite_cache and cache_path.exists(): - try: - examples = torch.load(cache_path) - assert isinstance(examples, list) - return examples - - except Exception: - print(f"failed to load from {cache_path}, retokenizing {data_path}") - data_path = Path(data_path) - - lns = lmap(str.strip, data_path.open().readlines()) - lns = [prefix + text for text in lns] - assert lns, f"found empty file at {data_path}" - examples = [] - for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"): - tokenized = tokenizer( - [text], - max_length=max_length, - padding="max_length" if pad_to_max_length else None, - truncation=True, - return_tensors=return_tensors, - **extra_kw, - ) - assert tokenized.input_ids.shape[1] == max_length - examples.append(tokenized) - torch.save(lmap(dict, examples), cache_path.open("wb")) - return examples + return tokenizer( + [line], + max_length=max_length, + padding="max_length" if pad_to_max_length else None, + truncation=True, + return_tensors=return_tensors, + **extra_kw, + ) def lmap(f: Callable, x: Iterable) -> List: @@ -80,73 +52,111 @@ def trim_batch( return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) -class SummarizationDataset(Dataset): +class Seq2SeqDataset(Dataset): def __init__( self, tokenizer, data_dir, + max_source_length, + max_target_length, type_path="train", - max_source_length=1024, - max_target_length=56, n_obs=None, - overwrite_cache=False, - prefix="", src_lang=None, tgt_lang=None, + prefix="", ): super().__init__() - # FIXME: the rstrip logic strips all the chars, it seems. - tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer") - if hasattr(tokenizer, "set_lang") and src_lang is not None: - tokenizer.set_lang(src_lang) # HACK: only applies to mbart - self.source = encode_file( - tokenizer, - os.path.join(data_dir, type_path + ".source"), - max_source_length, - overwrite_cache=overwrite_cache, - prefix=prefix, - tok_name=tok_name, - ) - tgt_path = os.path.join(data_dir, type_path + ".target") - if hasattr(tokenizer, "set_lang"): - assert tgt_lang is not None, "--tgt_lang must be passed to build a translation" - tokenizer.set_lang(tgt_lang) # HACK: only applies to mbart - self.target = encode_file( - tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name - ) + self.src_file = Path(data_dir).joinpath(type_path + ".source") + self.tgt_file = Path(data_dir).joinpath(type_path + ".target") + self.src_lens = self.get_char_lens(self.src_file) + self.max_source_length = max_source_length + self.max_target_length = max_target_length + assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" + self.tokenizer = tokenizer + self.prefix = prefix if n_obs is not None: - self.source = self.source[:n_obs] - self.target = self.target[:n_obs] - self.pad_token_id = tokenizer.pad_token_id + self.src_lens = self.src_lens[:n_obs] + self.pad_token_id = self.tokenizer.pad_token_id + self.src_lang = src_lang + self.tgt_lang = tgt_lang def __len__(self): - return len(self.source) + return len(self.src_lens) - def __getitem__(self, index): - source_ids = self.source[index]["input_ids"].squeeze() - target_ids = self.target[index]["input_ids"].squeeze() - src_mask = self.source[index]["attention_mask"].squeeze() - return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids} + def __getitem__(self, index) -> Dict[str, torch.Tensor]: + index = index + 1 # linecache starts at 1 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") + tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") + assert source_line, f"empty source line for index {index}" + assert tgt_line, f"empty tgt line for index {index}" + source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length) + target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length) + + source_ids = source_inputs["input_ids"].squeeze() + target_ids = target_inputs["input_ids"].squeeze() + src_mask = source_inputs["attention_mask"].squeeze() + return { + "input_ids": source_ids, + "attention_mask": src_mask, + "decoder_input_ids": target_ids, + } @staticmethod - def trim_seq2seq_batch(batch, pad_token_id): + def get_char_lens(data_file): + return [len(x) for x in Path(data_file).open().readlines()] + + @staticmethod + def trim_seq2seq_batch(batch, pad_token_id) -> tuple: y = trim_batch(batch["decoder_input_ids"], pad_token_id) source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"]) return source_ids, source_mask, y - def collate_fn(self, batch) -> dict: + def collate_fn(self, batch) -> Dict[str, torch.Tensor]: input_ids = torch.stack([x["input_ids"] for x in batch]) masks = torch.stack([x["attention_mask"] for x in batch]) target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) pad_token_id = self.pad_token_id y = trim_batch(target_ids, pad_token_id) source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) - batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y} + batch = { + "input_ids": source_ids, + "attention_mask": source_mask, + "decoder_input_ids": y, + } return batch def make_sortish_sampler(self, batch_size): - lens = [x["input_ids"].ne(self.pad_token_id).sum() for x in self.source] - return SortishSampler(lens, batch_size) + return SortishSampler(self.src_lens, batch_size) + + +class MBartDataset(Seq2SeqDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.max_source_length != self.max_target_length: + warnings.warn( + f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides." + ) + + def __getitem__(self, index) -> Dict[str, str]: + index = index + 1 # linecache starts at 1 + source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") + tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") + assert source_line, f"empty source line for index {index}" + assert tgt_line, f"empty tgt line for index {index}" + return { + "tgt_texts": source_line, + "src_texts": tgt_line, + } + + def collate_fn(self, batch) -> Dict[str, torch.Tensor]: + batch_encoding = self.tokenizer.prepare_translation_batch( + [x["src_texts"] for x in batch], + src_lang=self.src_lang, + tgt_texts=[x["tgt_texts"] for x in batch], + tgt_lang=self.tgt_lang, + max_length=self.max_source_length, + ) + return batch_encoding.data class SortishSampler(Sampler): diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py index 2b6f7240c5..e23abcb82a 100644 --- a/src/transformers/tokenization_bart.py +++ b/src/transformers/tokenization_bart.py @@ -118,12 +118,7 @@ class MBartTokenizer(XLMRobertaTokenizer): self.fairseq_tokens_to_ids.update(self.lang_code_to_id) self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} self._additional_special_tokens = list(self.lang_code_to_id.keys()) - self.reset_special_tokens() - - def reset_special_tokens(self) -> None: - """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code].""" - self.prefix_tokens = [] - self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX")) def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None @@ -183,12 +178,6 @@ class MBartTokenizer(XLMRobertaTokenizer): return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones - def set_lang(self, lang: str) -> None: - """Set the current language code in order to call tokenizer properly.""" - self.cur_lang_code = self.lang_code_to_id[lang] - self.prefix_tokens = [self.cur_lang_code] - self.suffix_tokens = [self.eos_token_id] - def prepare_translation_batch( self, src_texts: List[str], @@ -215,7 +204,7 @@ class MBartTokenizer(XLMRobertaTokenizer): """ if max_length is None: max_length = self.max_len - self.cur_lang_code = self.lang_code_to_id[src_lang] + self.set_src_lang_special_tokens(src_lang) model_inputs: BatchEncoding = self( src_texts, add_special_tokens=True, @@ -227,7 +216,7 @@ class MBartTokenizer(XLMRobertaTokenizer): ) if tgt_texts is None: return model_inputs - self.set_lang(tgt_lang) + self.set_tgt_lang_special_tokens(tgt_lang) decoder_inputs: BatchEncoding = self( tgt_texts, add_special_tokens=True, @@ -239,6 +228,18 @@ class MBartTokenizer(XLMRobertaTokenizer): ) for k, v in decoder_inputs.items(): model_inputs[f"decoder_{k}"] = v - self.cur_lang_code = self.lang_code_to_id[src_lang] - self.reset_special_tokens() # sets to src_lang + + self.set_src_lang_special_tokens(src_lang) # sets to src_lang return model_inputs + + def set_src_lang_special_tokens(self, src_lang) -> None: + """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, cur_lang_code].""" + self.cur_lang_code = self.lang_code_to_id[src_lang] + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id, self.cur_lang_code] + + def set_tgt_lang_special_tokens(self, lang: str) -> None: + """Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos].""" + self.cur_lang_code = self.lang_code_to_id[lang] + self.prefix_tokens = [self.cur_lang_code] + self.suffix_tokens = [self.eos_token_id]