prepare_seq2seq_batch makes labels; decoder_input_ids are made later. (#6654)
* broken test * batch parity * tests pass * boom boom * boom boom * split out bart tokenizer tests * fix tests * boom boom * Fixed dataset bug * Fix marian * Undo extra * Get marian working * Fix t5 tok tests * Test passing * Cleanup * better assert msg * require torch * Fix mbart tests * undo extra decoder_attn_mask change * Fix import * pegasus tokenizer can ignore src_lang kwargs * unused kwarg test cov * boom boom * add todo for pegasus issue * cover one word translation edge case * Cleanup * doc
This commit is contained in:
@@ -71,8 +71,8 @@ Summarization Tips:
|
|||||||
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
|
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
|
||||||
|
|
||||||
**Update 2018-07-18**
|
**Update 2018-07-18**
|
||||||
Datasets: `Seq2SeqDataset` should be used for all tokenizers without a `prepare_seq2seq_batch` method. For those who do (like Marian, MBart), `TranslationDataset` should be used.**
|
Datasets: `LegacySeq2SeqDataset` will be used for all tokenizers without a `prepare_seq2seq_batch` method. Otherwise, `Seq2SeqDataset` will be used.
|
||||||
A new dataset is needed to support multilingual tasks.
|
Future work/help wanted: A new dataset to support multilingual tasks.
|
||||||
|
|
||||||
|
|
||||||
### Command Line Options
|
### Command Line Options
|
||||||
@@ -106,7 +106,7 @@ The following command should work on a 16GB GPU:
|
|||||||
--train_batch_size=1 \
|
--train_batch_size=1 \
|
||||||
--eval_batch_size=1 \
|
--eval_batch_size=1 \
|
||||||
--output_dir=xsum_results \
|
--output_dir=xsum_results \
|
||||||
--num_train_epochs 1 \
|
--num_train_epochs 6 \
|
||||||
--model_name_or_path facebook/bart-large
|
--model_name_or_path facebook/bart-large
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@@ -11,6 +12,7 @@ from torch.nn import functional as F
|
|||||||
|
|
||||||
from lightning_base import generic_train
|
from lightning_base import generic_train
|
||||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
||||||
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -22,6 +24,7 @@ try:
|
|||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
calculate_bleu,
|
calculate_bleu,
|
||||||
freeze_params,
|
freeze_params,
|
||||||
|
label_smoothed_nll_loss,
|
||||||
pickle_load,
|
pickle_load,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
)
|
)
|
||||||
@@ -34,12 +37,15 @@ except ImportError:
|
|||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
calculate_bleu,
|
calculate_bleu,
|
||||||
freeze_params,
|
freeze_params,
|
||||||
|
label_smoothed_nll_loss,
|
||||||
pickle_load,
|
pickle_load,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BartSummarizationDistiller(SummarizationModule):
|
class BartSummarizationDistiller(SummarizationModule):
|
||||||
|
"""Supports Bart, Pegasus and other models that inherit from Bart."""
|
||||||
|
|
||||||
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
|
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||||
|
|
||||||
def __init__(self, hparams):
|
def __init__(self, hparams):
|
||||||
@@ -160,22 +166,32 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
def _step(self, batch):
|
def _step(self, batch):
|
||||||
# assert is_frozen(self.teacher)
|
# assert is_frozen(self.teacher)
|
||||||
pad_token_id = self.tokenizer.pad_token_id
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
input_ids, src_mask, tgt_ids = batch["input_ids"], batch["attention_mask"], batch["labels"]
|
||||||
decoder_input_ids = y[:, :-1].contiguous()
|
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
|
||||||
labels = y[:, 1:].clone()
|
|
||||||
labels[y[:, 1:] == pad_token_id] = -100
|
|
||||||
# noinspection PyCallingNonCallable
|
# noinspection PyCallingNonCallable
|
||||||
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=src_mask,
|
attention_mask=src_mask,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
labels=labels,
|
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
)
|
use_cache=False,
|
||||||
|
) # TODO(@sshleifer): return_dict=True cleanup
|
||||||
|
|
||||||
|
# Same cross entropy vs. label smoothing logic as finetune.py
|
||||||
|
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||||
|
if self.hparams.label_smoothing == 0:
|
||||||
|
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||||
|
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
|
||||||
|
else:
|
||||||
|
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
|
||||||
|
student_lm_loss, _ = label_smoothed_nll_loss(
|
||||||
|
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||||
|
)
|
||||||
|
|
||||||
def zero_tensor():
|
def zero_tensor():
|
||||||
return torch.tensor(0.0).type_as(sloss)
|
return torch.tensor(0.0).type_as(student_lm_loss)
|
||||||
|
|
||||||
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
||||||
if self.different_encoder:
|
if self.different_encoder:
|
||||||
@@ -199,29 +215,26 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
attention_mask=src_mask,
|
attention_mask=src_mask,
|
||||||
encoder_outputs=teacher_enc_outputs,
|
encoder_outputs=teacher_enc_outputs,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
lm_labels=labels,
|
lm_labels=tgt_ids,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
)
|
)
|
||||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
|
||||||
if self.alpha_hid > 0:
|
if self.alpha_hid > 0:
|
||||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
||||||
|
|
||||||
blended_loss = (
|
blended_loss = (
|
||||||
self.alpha_ce * loss_ce
|
self.alpha_ce * loss_ce
|
||||||
+ self.alpha_mlm * sloss
|
+ self.alpha_mlm * student_lm_loss
|
||||||
+ self.hparams.alpha_encoder_loss * loss_encoder
|
+ self.hparams.alpha_encoder_loss * loss_encoder
|
||||||
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
||||||
)
|
)
|
||||||
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
|
return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
|
||||||
|
|
||||||
def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
|
def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
|
||||||
assert not isinstance(
|
msg = "expected list or tuple for hidden_states, got tensor of shape: "
|
||||||
hidden_states, torch.Tensor
|
assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
|
||||||
), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}"
|
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
|
||||||
assert not isinstance(
|
|
||||||
hidden_states_T, torch.Tensor
|
|
||||||
), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}"
|
|
||||||
mask = attention_mask.to(hidden_states[0])
|
mask = attention_mask.to(hidden_states[0])
|
||||||
valid_count = mask.sum() * hidden_states[0].size(-1)
|
valid_count = mask.sum() * hidden_states[0].size(-1)
|
||||||
hidden_losses = [
|
hidden_losses = [
|
||||||
@@ -233,7 +246,7 @@ class BartSummarizationDistiller(SummarizationModule):
|
|||||||
|
|
||||||
|
|
||||||
def add_distill_args(parser):
|
def add_distill_args(parser):
|
||||||
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
|
parser.add_argument("--teacher", type=str)
|
||||||
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||||
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||||
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
||||||
@@ -245,8 +258,9 @@ def add_distill_args(parser):
|
|||||||
|
|
||||||
|
|
||||||
class BartTranslationDistiller(BartSummarizationDistiller):
|
class BartTranslationDistiller(BartSummarizationDistiller):
|
||||||
|
"""Supports Mbart, Marian, other models that inherit from Bart."""
|
||||||
|
|
||||||
mode = "translation"
|
mode = "translation"
|
||||||
loss_names = ["loss"]
|
|
||||||
metric_names = ["bleu"]
|
metric_names = ["bleu"]
|
||||||
val_metric = "bleu"
|
val_metric = "bleu"
|
||||||
|
|
||||||
@@ -368,7 +382,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
|
|||||||
attention_mask=source_mask,
|
attention_mask=source_mask,
|
||||||
encoder_outputs=teacher_enc_outputs,
|
encoder_outputs=teacher_enc_outputs,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
lm_labels=labels,
|
labels=labels,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
)
|
)
|
||||||
@@ -402,6 +416,7 @@ def create_module(args):
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
||||||
|
# TODO(SS): DELETE?
|
||||||
exp_dir = ckpt_path.parent
|
exp_dir = ckpt_path.parent
|
||||||
if dest_dir is None:
|
if dest_dir is None:
|
||||||
dest_dir = exp_dir
|
dest_dir = exp_dir
|
||||||
@@ -424,33 +439,40 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
|||||||
trainer.test(model)
|
trainer.test(model)
|
||||||
|
|
||||||
|
|
||||||
def get_layers_to_copy(n_to_get, tot):
|
LAYERS_TO_COPY = {
|
||||||
all_layers = list(range(tot))
|
# maps num layers in student -> which teacher layers to copy.
|
||||||
if tot == 12: # Alternating for special cases
|
# 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
|
||||||
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
|
12: {
|
||||||
1: [0],
|
1: [0],
|
||||||
2: [0, 6],
|
2: [0, 6],
|
||||||
3: [0, 6, 11],
|
3: [0, 6, 11],
|
||||||
4: [0, 4, 8, 11],
|
4: [0, 4, 8, 11],
|
||||||
6: [0, 2, 4, 7, 9, 11],
|
6: [0, 2, 4, 7, 9, 11],
|
||||||
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
||||||
12: all_layers,
|
12: list(range(12)),
|
||||||
}
|
},
|
||||||
return layers_to_copy[n_to_get]
|
16: { # maps num layers in student -> which teacher layers to copy
|
||||||
elif tot == 16:
|
1: [0],
|
||||||
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
|
2: [0, 8],
|
||||||
1: [0],
|
3: [0, 8, 15],
|
||||||
2: [0, 8],
|
4: [0, 5, 10, 15],
|
||||||
3: [0, 8, 15],
|
6: [0, 3, 6, 9, 12, 15],
|
||||||
4: [0, 5, 10, 15],
|
8: [0, 2, 4, 6, 8, 10, 12, 15],
|
||||||
6: [0, 3, 6, 9, 12, 15],
|
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
|
||||||
8: [0, 2, 4, 6, 8, 10, 12, 15],
|
16: list(range(16)),
|
||||||
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
|
},
|
||||||
16: all_layers,
|
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
|
||||||
}
|
}
|
||||||
return layers_to_copy[n_to_get]
|
|
||||||
else:
|
|
||||||
return all_layers[:n_to_get] # TODO: better version on theseus-bart branch
|
def get_layers_to_copy(n_student, n_teacher):
|
||||||
|
try:
|
||||||
|
return LAYERS_TO_COPY[n_teacher][n_student]
|
||||||
|
except KeyError:
|
||||||
|
warnings.warn(
|
||||||
|
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
|
||||||
|
)
|
||||||
|
return list(range(n_student))
|
||||||
|
|
||||||
|
|
||||||
def distill_main(args):
|
def distill_main(args):
|
||||||
|
|||||||
@@ -13,15 +13,16 @@ import torch
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||||
from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration
|
from transformers import MBartTokenizer, T5ForConditionalGeneration
|
||||||
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ROUGE_KEYS,
|
ROUGE_KEYS,
|
||||||
|
LegacySeq2SeqDataset,
|
||||||
Seq2SeqDataset,
|
Seq2SeqDataset,
|
||||||
TranslationDataset,
|
|
||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
calculate_bleu,
|
calculate_bleu,
|
||||||
calculate_rouge,
|
calculate_rouge,
|
||||||
@@ -39,8 +40,8 @@ except ImportError:
|
|||||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||||
from utils import (
|
from utils import (
|
||||||
ROUGE_KEYS,
|
ROUGE_KEYS,
|
||||||
|
LegacySeq2SeqDataset,
|
||||||
Seq2SeqDataset,
|
Seq2SeqDataset,
|
||||||
TranslationDataset,
|
|
||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
calculate_bleu,
|
calculate_bleu,
|
||||||
calculate_rouge,
|
calculate_rouge,
|
||||||
@@ -102,14 +103,13 @@ class SummarizationModule(BaseTransformer):
|
|||||||
|
|
||||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||||
self.num_workers = hparams.num_workers
|
self.num_workers = hparams.num_workers
|
||||||
self.decoder_start_token_id = None
|
self.decoder_start_token_id = None # default to config
|
||||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||||
self.model.config.decoder_start_token_id = self.decoder_start_token_id
|
self.model.config.decoder_start_token_id = self.decoder_start_token_id
|
||||||
if isinstance(self.tokenizer, MBartTokenizer) or isinstance(self.tokenizer, MarianTokenizer):
|
self.dataset_class = (
|
||||||
self.dataset_class = TranslationDataset
|
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
||||||
else:
|
)
|
||||||
self.dataset_class = Seq2SeqDataset
|
|
||||||
|
|
||||||
def freeze_embeds(self):
|
def freeze_embeds(self):
|
||||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||||
@@ -134,27 +134,25 @@ class SummarizationModule(BaseTransformer):
|
|||||||
|
|
||||||
def _step(self, batch: dict) -> Tuple:
|
def _step(self, batch: dict) -> Tuple:
|
||||||
pad_token_id = self.tokenizer.pad_token_id
|
pad_token_id = self.tokenizer.pad_token_id
|
||||||
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
|
||||||
|
tgt_ids = batch["labels"]
|
||||||
if isinstance(self.model, T5ForConditionalGeneration):
|
if isinstance(self.model, T5ForConditionalGeneration):
|
||||||
decoder_input_ids = self.model._shift_right(target_ids)
|
decoder_input_ids = self.model._shift_right(tgt_ids)
|
||||||
lm_labels = target_ids
|
|
||||||
else:
|
else:
|
||||||
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
|
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
|
||||||
lm_labels = target_ids[:, 1:].clone() # why clone?
|
|
||||||
|
|
||||||
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
|
|
||||||
|
|
||||||
|
outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
|
||||||
|
lm_logits = outputs[0]
|
||||||
if self.hparams.label_smoothing == 0:
|
if self.hparams.label_smoothing == 0:
|
||||||
# Same behavior as modeling_bart.py
|
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||||
lm_logits = outputs[0]
|
|
||||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||||
loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1))
|
loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
|
||||||
else:
|
else:
|
||||||
lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
|
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
|
||||||
loss, nll_loss = label_smoothed_nll_loss(
|
loss, nll_loss = label_smoothed_nll_loss(
|
||||||
lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id
|
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||||
)
|
)
|
||||||
return (loss,)
|
return (loss,)
|
||||||
|
|
||||||
@@ -167,7 +165,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
|
|
||||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||||
# tokens per batch
|
# tokens per batch
|
||||||
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["decoder_input_ids"].ne(self.pad).sum()
|
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
|
||||||
return {"loss": loss_tensors[0], "log": logs}
|
return {"loss": loss_tensors[0], "log": logs}
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx) -> Dict:
|
def validation_step(self, batch, batch_idx) -> Dict:
|
||||||
@@ -204,7 +202,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
)
|
)
|
||||||
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
|
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
|
||||||
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
||||||
target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
|
target: List[str] = self.ids_to_clean_text(batch["labels"])
|
||||||
loss_tensors = self._step(batch)
|
loss_tensors = self._step(batch)
|
||||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||||
rouge: Dict = self.calc_generative_metrics(preds, target)
|
rouge: Dict = self.calc_generative_metrics(preds, target)
|
||||||
|
|||||||
@@ -132,4 +132,6 @@ def run_generate():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# Usage for MT:
|
||||||
|
# python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target --score_path $save_dir/test_bleu.json --task translation $@
|
||||||
run_generate()
|
run_generate()
|
||||||
|
|||||||
@@ -10,18 +10,18 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from pytest import param
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
import lightning_base
|
import lightning_base
|
||||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu
|
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu
|
||||||
|
|
||||||
from .distillation import distill_main, evaluate_checkpoint
|
from .distillation import distill_main, evaluate_checkpoint
|
||||||
from .finetune import SummarizationModule, main
|
from .finetune import SummarizationModule, main
|
||||||
from .pack_dataset import pack_data_dir
|
from .pack_dataset import pack_data_dir
|
||||||
from .run_eval import generate_summaries_or_translations, run_generate
|
from .run_eval import generate_summaries_or_translations, run_generate
|
||||||
from .utils import Seq2SeqDataset, TranslationDataset, label_smoothed_nll_loss, lmap, load_json
|
from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@@ -452,18 +452,27 @@ def test_pack_dataset():
|
|||||||
assert orig_paths == new_paths
|
assert orig_paths == new_paths
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(["tok_name"], [pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)])
|
@pytest.mark.parametrize(
|
||||||
def test_mbart_dataset_truncation(tok_name):
|
["tok_name"],
|
||||||
|
[
|
||||||
|
pytest.param(MBART_TINY),
|
||||||
|
pytest.param(MARIAN_TINY),
|
||||||
|
pytest.param(T5_TINY),
|
||||||
|
pytest.param(BART_TINY),
|
||||||
|
pytest.param("google/pegasus-xsum"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_seq2seq_dataset_truncation(tok_name):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||||
tmp_dir = make_test_data_dir()
|
tmp_dir = make_test_data_dir()
|
||||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||||
max_src_len = 4
|
max_src_len = 4
|
||||||
max_tgt_len = 8
|
max_tgt_len = 8
|
||||||
assert max_len_target > max_src_len # Truncated
|
assert max_len_target > max_src_len # Will be truncated
|
||||||
assert max_len_source > max_src_len
|
assert max_len_source > max_src_len # Will be truncated
|
||||||
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
|
src_lang, tgt_lang = "ro_RO", "de_DE" # ignored for all but mbart, but never causes error.
|
||||||
train_dataset = TranslationDataset(
|
train_dataset = Seq2SeqDataset(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
data_dir=tmp_dir,
|
data_dir=tmp_dir,
|
||||||
type_path="train",
|
type_path="train",
|
||||||
@@ -479,10 +488,11 @@ def test_mbart_dataset_truncation(tok_name):
|
|||||||
# show that articles were trimmed.
|
# show that articles were trimmed.
|
||||||
assert batch["input_ids"].shape[1] == max_src_len
|
assert batch["input_ids"].shape[1] == max_src_len
|
||||||
# show that targets are the same len
|
# show that targets are the same len
|
||||||
assert batch["decoder_input_ids"].shape[1] == max_tgt_len
|
assert batch["labels"].shape[1] == max_tgt_len
|
||||||
if tok_name == MARIAN_TINY:
|
if tok_name != MBART_TINY:
|
||||||
continue
|
continue
|
||||||
# check language codes in correct place
|
# check language codes in correct place
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], tokenizer.pad_token_id)
|
||||||
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
||||||
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
||||||
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
||||||
@@ -491,14 +501,14 @@ def test_mbart_dataset_truncation(tok_name):
|
|||||||
break # No need to test every batch
|
break # No need to test every batch
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
|
@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")])
|
||||||
def test_summarization_dataset_truncation(tok):
|
def test_legacy_dataset_truncation(tok):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(tok)
|
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||||
tmp_dir = make_test_data_dir()
|
tmp_dir = make_test_data_dir()
|
||||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||||
trunc_target = 4
|
trunc_target = 4
|
||||||
train_dataset = Seq2SeqDataset(
|
train_dataset = LegacySeq2SeqDataset(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
data_dir=tmp_dir,
|
data_dir=tmp_dir,
|
||||||
type_path="train",
|
type_path="train",
|
||||||
@@ -512,6 +522,6 @@ def test_summarization_dataset_truncation(tok):
|
|||||||
assert batch["input_ids"].shape[1] == max_len_source
|
assert batch["input_ids"].shape[1] == max_len_source
|
||||||
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
|
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
|
||||||
# show that targets were truncated
|
# show that targets were truncated
|
||||||
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
|
assert batch["labels"].shape[1] == trunc_target # Truncated
|
||||||
assert max_len_target > trunc_target # Truncated
|
assert max_len_target > trunc_target # Truncated
|
||||||
break # No need to test every batch
|
break # No need to test every batch
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import json
|
|||||||
import linecache
|
import linecache
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import warnings
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Iterable, List
|
from typing import Callable, Dict, Iterable, List
|
||||||
@@ -41,6 +40,7 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
|
|||||||
|
|
||||||
|
|
||||||
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
|
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
|
||||||
|
"""Only used by LegacyDataset"""
|
||||||
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
|
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
|
||||||
return tokenizer(
|
return tokenizer(
|
||||||
[line],
|
[line],
|
||||||
@@ -75,7 +75,7 @@ def trim_batch(
|
|||||||
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
||||||
|
|
||||||
|
|
||||||
class Seq2SeqDataset(Dataset):
|
class AbstractSeq2SeqDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -102,11 +102,28 @@ class Seq2SeqDataset(Dataset):
|
|||||||
self.pad_token_id = self.tokenizer.pad_token_id
|
self.pad_token_id = self.tokenizer.pad_token_id
|
||||||
self.src_lang = src_lang
|
self.src_lang = src_lang
|
||||||
self.tgt_lang = tgt_lang
|
self.tgt_lang = tgt_lang
|
||||||
|
self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.src_lens)
|
return len(self.src_lens)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_char_lens(data_file):
|
||||||
|
return [len(x) for x in Path(data_file).open().readlines()]
|
||||||
|
|
||||||
|
def make_sortish_sampler(self, batch_size):
|
||||||
|
return SortishSampler(self.src_lens, batch_size)
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
raise NotImplementedError("You must implement this")
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
raise NotImplementedError("You must implement this")
|
||||||
|
|
||||||
|
|
||||||
|
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
|
||||||
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
||||||
|
"""Call tokenizer on src and tgt_lines"""
|
||||||
index = index + 1 # linecache starts at 1
|
index = index + 1 # linecache starts at 1
|
||||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||||
@@ -121,42 +138,27 @@ class Seq2SeqDataset(Dataset):
|
|||||||
return {
|
return {
|
||||||
"input_ids": source_ids,
|
"input_ids": source_ids,
|
||||||
"attention_mask": src_mask,
|
"attention_mask": src_mask,
|
||||||
"decoder_input_ids": target_ids,
|
"labels": target_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_char_lens(data_file):
|
|
||||||
return [len(x) for x in Path(data_file).open().readlines()]
|
|
||||||
|
|
||||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||||
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
|
target_ids = torch.stack([x["labels"] for x in batch])
|
||||||
pad_token_id = self.pad_token_id
|
pad_token_id = self.pad_token_id
|
||||||
y = trim_batch(target_ids, pad_token_id)
|
y = trim_batch(target_ids, pad_token_id)
|
||||||
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
|
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
|
||||||
batch = {
|
batch = {
|
||||||
"input_ids": source_ids,
|
"input_ids": source_ids,
|
||||||
"attention_mask": source_mask,
|
"attention_mask": source_mask,
|
||||||
"decoder_input_ids": y,
|
"labels": y,
|
||||||
}
|
}
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def make_sortish_sampler(self, batch_size):
|
|
||||||
return SortishSampler(self.src_lens, batch_size)
|
|
||||||
|
|
||||||
|
class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
||||||
class TranslationDataset(Seq2SeqDataset):
|
|
||||||
"""A dataset that calls prepare_seq2seq_batch."""
|
"""A dataset that calls prepare_seq2seq_batch."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
if self.max_source_length != self.max_target_length:
|
|
||||||
warnings.warn(
|
|
||||||
f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. "
|
|
||||||
f"Imbalanced sequence lengths may be undesired for translation tasks"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __getitem__(self, index) -> Dict[str, str]:
|
def __getitem__(self, index) -> Dict[str, str]:
|
||||||
index = index + 1 # linecache starts at 1
|
index = index + 1 # linecache starts at 1
|
||||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||||
@@ -169,6 +171,7 @@ class TranslationDataset(Seq2SeqDataset):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||||
|
"""Call prepare_seq2seq_batch."""
|
||||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||||
[x["src_texts"] for x in batch],
|
[x["src_texts"] for x in batch],
|
||||||
src_lang=self.src_lang,
|
src_lang=self.src_lang,
|
||||||
@@ -176,6 +179,8 @@ class TranslationDataset(Seq2SeqDataset):
|
|||||||
tgt_lang=self.tgt_lang,
|
tgt_lang=self.tgt_lang,
|
||||||
max_length=self.max_source_length,
|
max_length=self.max_source_length,
|
||||||
max_target_length=self.max_target_length,
|
max_target_length=self.max_target_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
add_prefix_space=self.add_prefix_space,
|
||||||
)
|
)
|
||||||
return batch_encoding.data
|
return batch_encoding.data
|
||||||
|
|
||||||
@@ -276,7 +281,11 @@ def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer
|
|||||||
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
|
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
|
||||||
|
|
||||||
|
|
||||||
|
# Utilities for freezing parameters and checking whether they are frozen
|
||||||
|
|
||||||
|
|
||||||
def freeze_params(model: nn.Module):
|
def freeze_params(model: nn.Module):
|
||||||
|
"""Set requires_grad=False for each of model.parameters()"""
|
||||||
for par in model.parameters():
|
for par in model.parameters():
|
||||||
par.requires_grad = False
|
par.requires_grad = False
|
||||||
|
|
||||||
|
|||||||
@@ -151,6 +151,9 @@ def _prepare_bart_decoder_inputs(
|
|||||||
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
|
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
|
||||||
else:
|
else:
|
||||||
decoder_padding_mask = invert_mask(decoder_padding_mask)
|
decoder_padding_mask = invert_mask(decoder_padding_mask)
|
||||||
|
if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
|
||||||
|
# never mask leading token, even if it is pad
|
||||||
|
decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
|
||||||
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
|
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
|
||||||
dtype=causal_mask_dtype, device=decoder_input_ids.device
|
dtype=causal_mask_dtype, device=decoder_input_ids.device
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -636,7 +636,7 @@ class T5PreTrainedModel(PreTrainedModel):
|
|||||||
# replace possible -100 values in labels by `pad_token_id`
|
# replace possible -100 values in labels by `pad_token_id`
|
||||||
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
||||||
|
|
||||||
assert torch.all(shifted_input_ids >= 0).item(), "Verify that `labels` has only positive values and -100"
|
assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ _all_bart_models = [
|
|||||||
"facebook/bart-large-cnn",
|
"facebook/bart-large-cnn",
|
||||||
"facebook/bart-large-xsum",
|
"facebook/bart-large-xsum",
|
||||||
"yjernite/bart_eli5",
|
"yjernite/bart_eli5",
|
||||||
|
# This is not exhaustive: see https://huggingface.co/models?filter=bart
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -117,6 +118,8 @@ class BartTokenizer(RobertaTokenizer):
|
|||||||
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
|
The full set of keys ``[input_ids, attention_mask, labels]``,
|
||||||
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
|
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
|
||||||
"""
|
"""
|
||||||
|
kwargs.pop("src_lang", None)
|
||||||
|
kwargs.pop("tgt_lang", None)
|
||||||
if max_length is None:
|
if max_length is None:
|
||||||
max_length = self.model_max_length
|
max_length = self.model_max_length
|
||||||
model_inputs: BatchEncoding = self(
|
model_inputs: BatchEncoding = self(
|
||||||
@@ -133,7 +136,7 @@ class BartTokenizer(RobertaTokenizer):
|
|||||||
# Process tgt_texts
|
# Process tgt_texts
|
||||||
if max_target_length is None:
|
if max_target_length is None:
|
||||||
max_target_length = max_length
|
max_target_length = max_length
|
||||||
decoder_inputs: BatchEncoding = self(
|
labels = self(
|
||||||
tgt_texts,
|
tgt_texts,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
@@ -141,10 +144,8 @@ class BartTokenizer(RobertaTokenizer):
|
|||||||
max_length=max_target_length,
|
max_length=max_target_length,
|
||||||
truncation=truncation,
|
truncation=truncation,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)["input_ids"]
|
||||||
for k, v in decoder_inputs.items():
|
model_inputs["labels"] = labels
|
||||||
model_inputs[f"decoder_{k}"] = v
|
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
@@ -245,7 +246,7 @@ class BartTokenizerFast(RobertaTokenizerFast):
|
|||||||
# Process tgt_texts
|
# Process tgt_texts
|
||||||
if max_target_length is None:
|
if max_target_length is None:
|
||||||
max_target_length = max_length
|
max_target_length = max_length
|
||||||
decoder_inputs: BatchEncoding = self(
|
labels = self(
|
||||||
tgt_texts,
|
tgt_texts,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
@@ -253,8 +254,6 @@ class BartTokenizerFast(RobertaTokenizerFast):
|
|||||||
max_length=max_target_length,
|
max_length=max_target_length,
|
||||||
truncation=truncation,
|
truncation=truncation,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)["input_ids"]
|
||||||
for k, v in decoder_inputs.items():
|
model_inputs["labels"] = labels
|
||||||
model_inputs[f"decoder_{k}"] = v
|
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|||||||
@@ -160,9 +160,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
tokenizer_kwargs["max_length"] = max_target_length
|
tokenizer_kwargs["max_length"] = max_target_length
|
||||||
|
|
||||||
self.current_spm = self.spm_target
|
self.current_spm = self.spm_target
|
||||||
decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
|
model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||||
for k, v in decoder_inputs.items():
|
|
||||||
model_inputs[f"decoder_{k}"] = v
|
|
||||||
self.current_spm = self.spm_source
|
self.current_spm = self.spm_source
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
|||||||
@@ -98,32 +98,6 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
||||||
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
|
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(
|
|
||||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
|
||||||
) -> List[int]:
|
|
||||||
"""
|
|
||||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
|
||||||
by concatenating and adding special tokens. The special tokens depend on calling set_lang.
|
|
||||||
An MBART sequence has the following format, where ``X`` represents the sequence:
|
|
||||||
- ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
|
|
||||||
- ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
|
|
||||||
BOS is never used.
|
|
||||||
Pairs of sequences are not the expected use case, but they will be handled without a separator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_ids_0 (:obj:`List[int]`):
|
|
||||||
List of IDs to which the special tokens will be added
|
|
||||||
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
|
||||||
Optional second list of IDs for sequence pairs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
|
||||||
"""
|
|
||||||
if token_ids_1 is None:
|
|
||||||
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
|
|
||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
|
||||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
|
||||||
|
|
||||||
def get_special_tokens_mask(
|
def get_special_tokens_mask(
|
||||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
@@ -156,6 +130,32 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
||||||
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
||||||
|
|
||||||
|
def build_inputs_with_special_tokens(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
||||||
|
by concatenating and adding special tokens. The special tokens depend on calling set_lang.
|
||||||
|
An MBART sequence has the following format, where ``X`` represents the sequence:
|
||||||
|
- ``input_ids`` (for encoder) ``X [eos, src_lang_code]``
|
||||||
|
- ``decoder_input_ids``: (for decoder) ``[tgt_lang_code] X [eos]``
|
||||||
|
BOS is never used.
|
||||||
|
Pairs of sequences are not the expected use case, but they will be handled without a separator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (:obj:`List[int]`):
|
||||||
|
List of IDs to which the special tokens will be added
|
||||||
|
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||||
|
"""
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
|
||||||
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
@add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||||
def prepare_seq2seq_batch(
|
def prepare_seq2seq_batch(
|
||||||
self,
|
self,
|
||||||
@@ -251,7 +251,8 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
if max_target_length is None:
|
if max_target_length is None:
|
||||||
max_target_length = max_length
|
max_target_length = max_length
|
||||||
self.set_tgt_lang_special_tokens(tgt_lang)
|
self.set_tgt_lang_special_tokens(tgt_lang)
|
||||||
decoder_inputs: BatchEncoding = self(
|
|
||||||
|
labels = self(
|
||||||
tgt_texts,
|
tgt_texts,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
@@ -259,10 +260,8 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
max_length=max_target_length,
|
max_length=max_target_length,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)["input_ids"]
|
||||||
for k, v in decoder_inputs.items():
|
model_inputs["labels"] = labels
|
||||||
model_inputs[f"decoder_{k}"] = v
|
|
||||||
|
|
||||||
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
|
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
@@ -275,5 +274,5 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
def set_tgt_lang_special_tokens(self, lang: str) -> None:
|
def set_tgt_lang_special_tokens(self, lang: str) -> None:
|
||||||
"""Reset the special tokens to the target language setting. Prefix [tgt_lang_code], suffix =[eos]."""
|
"""Reset the special tokens to the target language setting. No prefix, suffix=[eos, tgt_lang_code]."""
|
||||||
self.cur_lang_code = self.lang_code_to_id[lang]
|
self.cur_lang_code = self.lang_code_to_id[lang]
|
||||||
self.prefix_tokens = [self.cur_lang_code]
|
self.prefix_tokens = []
|
||||||
self.suffix_tokens = [self.eos_token_id]
|
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
|
||||||
|
|||||||
@@ -114,6 +114,7 @@ class PegasusTokenizer(ReformerTokenizer):
|
|||||||
return_tensors: str = "pt",
|
return_tensors: str = "pt",
|
||||||
truncation=True,
|
truncation=True,
|
||||||
padding="longest",
|
padding="longest",
|
||||||
|
**unused,
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
"""
|
"""
|
||||||
Prepare model inputs for summarization or translation.
|
Prepare model inputs for summarization or translation.
|
||||||
@@ -133,7 +134,9 @@ class PegasusTokenizer(ReformerTokenizer):
|
|||||||
return model_inputs
|
return model_inputs
|
||||||
if max_target_length is not None:
|
if max_target_length is not None:
|
||||||
tokenizer_kwargs["max_length"] = max_target_length
|
tokenizer_kwargs["max_length"] = max_target_length
|
||||||
decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
|
# TODO(@sshleifer): maybe tgt_texts = [self.pad_token + t for t in tgt_texts] # add decoder_start_token_id
|
||||||
for k, v in decoder_inputs.items():
|
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||||
model_inputs[f"decoder_{k}"] = v
|
model_inputs["labels"] = labels
|
||||||
|
# for k, v in decoder_inputs.items():
|
||||||
|
# model_inputs[f"decoder_{k}"] = v
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|||||||
@@ -346,7 +346,7 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
if max_length is None:
|
if max_length is None:
|
||||||
max_length = self.max_len
|
max_length = self.max_len
|
||||||
self.prefix_tokens = []
|
self.prefix_tokens = []
|
||||||
model_inputs: BatchEncoding = self(
|
model_inputs = self(
|
||||||
src_texts,
|
src_texts,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
@@ -362,7 +362,7 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
max_target_length = max_length
|
max_target_length = max_length
|
||||||
# set prefix_tokens for target text
|
# set prefix_tokens for target text
|
||||||
self.prefix_tokens = [self.pad_token_id]
|
self.prefix_tokens = [self.pad_token_id]
|
||||||
decoder_inputs: BatchEncoding = self(
|
labels_and_decoder_mask = self(
|
||||||
tgt_texts,
|
tgt_texts,
|
||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
@@ -371,8 +371,7 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
truncation=truncation,
|
truncation=truncation,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
for k, v in decoder_inputs.items():
|
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
|
||||||
model_inputs[f"decoder_{k}"] = v
|
model_inputs["decoder_attention_mask"] = labels_and_decoder_mask["attention_mask"]
|
||||||
|
|
||||||
self.prefix_tokens = []
|
self.prefix_tokens = []
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import unittest
|
|||||||
|
|
||||||
import timeout_decorator # noqa
|
import timeout_decorator # noqa
|
||||||
|
|
||||||
from transformers import BatchEncoding, is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
@@ -496,7 +496,7 @@ class BartModelIntegrationTests(unittest.TestCase):
|
|||||||
def test_xsum_summarization_same_as_fairseq(self):
|
def test_xsum_summarization_same_as_fairseq(self):
|
||||||
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum").to(torch_device)
|
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum").to(torch_device)
|
||||||
self.assertFalse(model.config.is_valid_mbart())
|
self.assertFalse(model.config.is_valid_mbart())
|
||||||
tok = BartTokenizer.from_pretrained("facebook/bart-large")
|
tok = self.default_tokenizer
|
||||||
|
|
||||||
EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
|
EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
|
||||||
dct = tok.batch_encode_plus(
|
dct = tok.batch_encode_plus(
|
||||||
@@ -585,84 +585,6 @@ class BartModelIntegrationTests(unittest.TestCase):
|
|||||||
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
|
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
|
||||||
# TODO(SS): add test case that hits max_length
|
# TODO(SS): add test case that hits max_length
|
||||||
|
|
||||||
def test_prepare_seq2seq_batch(self):
|
|
||||||
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
|
|
||||||
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
|
||||||
tgt_text = [
|
|
||||||
"Summary of the text.",
|
|
||||||
"Another summary.",
|
|
||||||
]
|
|
||||||
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
|
|
||||||
|
|
||||||
for tokenizer in tokenizers:
|
|
||||||
batch = tokenizer.prepare_seq2seq_batch(
|
|
||||||
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
|
|
||||||
)
|
|
||||||
self.assertIsInstance(batch, BatchEncoding)
|
|
||||||
|
|
||||||
self.assertEqual((2, 10), batch.input_ids.shape)
|
|
||||||
self.assertEqual((2, 10), batch.attention_mask.shape)
|
|
||||||
result = batch.input_ids.tolist()[0]
|
|
||||||
self.assertListEqual(expected_src_tokens, result)
|
|
||||||
# Test that special tokens are reset
|
|
||||||
|
|
||||||
def test_empty_target_text(self):
|
|
||||||
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
|
|
||||||
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
|
||||||
for tokenizer in tokenizers:
|
|
||||||
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
|
|
||||||
# check if input_ids are returned and no decoder_input_ids
|
|
||||||
self.assertIn("input_ids", batch)
|
|
||||||
self.assertIn("attention_mask", batch)
|
|
||||||
self.assertNotIn("decoder_input_ids", batch)
|
|
||||||
self.assertNotIn("decoder_attention_mask", batch)
|
|
||||||
|
|
||||||
def test_max_target_length(self):
|
|
||||||
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
|
|
||||||
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
|
||||||
tgt_text = [
|
|
||||||
"Summary of the text.",
|
|
||||||
"Another summary.",
|
|
||||||
]
|
|
||||||
for tokenizer in tokenizers:
|
|
||||||
batch = tokenizer.prepare_seq2seq_batch(
|
|
||||||
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
|
|
||||||
)
|
|
||||||
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
|
|
||||||
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
|
|
||||||
|
|
||||||
# test None max_target_length
|
|
||||||
batch = tokenizer.prepare_seq2seq_batch(
|
|
||||||
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
|
|
||||||
)
|
|
||||||
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
|
|
||||||
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
|
|
||||||
|
|
||||||
def test_outputs_not_longer_than_maxlen(self):
|
|
||||||
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
|
|
||||||
|
|
||||||
for tokenizer in tokenizers:
|
|
||||||
batch = tokenizer.prepare_seq2seq_batch(
|
|
||||||
["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
|
|
||||||
)
|
|
||||||
self.assertIsInstance(batch, BatchEncoding)
|
|
||||||
self.assertEqual(batch.input_ids.shape, (2, 1024))
|
|
||||||
|
|
||||||
def test_special_tokens(self):
|
|
||||||
tokenizers = [self.default_tokenizer, self.default_tokenizer_fast]
|
|
||||||
src_text = ["A long paragraph for summrization."]
|
|
||||||
tgt_text = [
|
|
||||||
"Summary of the text.",
|
|
||||||
]
|
|
||||||
for tokenizer in tokenizers:
|
|
||||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
|
|
||||||
input_ids = batch["input_ids"]
|
|
||||||
decoder_input_ids = batch["decoder_input_ids"]
|
|
||||||
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
|
|
||||||
self.assertTrue((decoder_input_ids[:, 0] == tokenizer.bos_token_id).all().item())
|
|
||||||
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
|
|
||||||
self.assertTrue((decoder_input_ids[:, -1] == tokenizer.eos_token_id).all().item())
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||||
|
|||||||
145
tests/test_tokenization_bart.py
Normal file
145
tests/test_tokenization_bart.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import BartTokenizer, BartTokenizerFast, BatchEncoding
|
||||||
|
from transformers.file_utils import cached_property
|
||||||
|
from transformers.testing_utils import require_torch
|
||||||
|
from transformers.tokenization_roberta import VOCAB_FILES_NAMES
|
||||||
|
|
||||||
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
||||||
|
tokenizer_class = BartTokenizer
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
vocab = [
|
||||||
|
"l",
|
||||||
|
"o",
|
||||||
|
"w",
|
||||||
|
"e",
|
||||||
|
"r",
|
||||||
|
"s",
|
||||||
|
"t",
|
||||||
|
"i",
|
||||||
|
"d",
|
||||||
|
"n",
|
||||||
|
"\u0120",
|
||||||
|
"\u0120l",
|
||||||
|
"\u0120n",
|
||||||
|
"\u0120lo",
|
||||||
|
"\u0120low",
|
||||||
|
"er",
|
||||||
|
"\u0120lowest",
|
||||||
|
"\u0120newer",
|
||||||
|
"\u0120wider",
|
||||||
|
"<unk>",
|
||||||
|
]
|
||||||
|
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||||
|
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||||
|
self.special_tokens_map = {"unk_token": "<unk>"}
|
||||||
|
|
||||||
|
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||||
|
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
|
||||||
|
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||||
|
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||||
|
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||||
|
fp.write("\n".join(merges))
|
||||||
|
|
||||||
|
def get_tokenizer(self, **kwargs):
|
||||||
|
kwargs.update(self.special_tokens_map)
|
||||||
|
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
|
def get_rust_tokenizer(self, **kwargs):
|
||||||
|
kwargs.update(self.special_tokens_map)
|
||||||
|
return BartTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
|
def get_input_output_texts(self, tokenizer):
|
||||||
|
return "lower newer", "lower newer"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def default_tokenizer(self):
|
||||||
|
return BartTokenizer.from_pretrained("facebook/bart-large")
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def default_tokenizer_fast(self):
|
||||||
|
return BartTokenizerFast.from_pretrained("facebook/bart-large")
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_prepare_seq2seq_batch(self):
|
||||||
|
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
||||||
|
tgt_text = [
|
||||||
|
"Summary of the text.",
|
||||||
|
"Another summary.",
|
||||||
|
]
|
||||||
|
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
|
||||||
|
|
||||||
|
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||||
|
batch = tokenizer.prepare_seq2seq_batch(
|
||||||
|
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
|
||||||
|
)
|
||||||
|
self.assertIsInstance(batch, BatchEncoding)
|
||||||
|
|
||||||
|
self.assertEqual((2, 10), batch.input_ids.shape)
|
||||||
|
self.assertEqual((2, 10), batch.attention_mask.shape)
|
||||||
|
result = batch.input_ids.tolist()[0]
|
||||||
|
self.assertListEqual(expected_src_tokens, result)
|
||||||
|
# Test that special tokens are reset
|
||||||
|
|
||||||
|
# Test Prepare Seq
|
||||||
|
@require_torch
|
||||||
|
def test_seq2seq_batch_empty_target_text(self):
|
||||||
|
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
||||||
|
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||||
|
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
|
||||||
|
# check if input_ids are returned and no labels
|
||||||
|
self.assertIn("input_ids", batch)
|
||||||
|
self.assertIn("attention_mask", batch)
|
||||||
|
self.assertNotIn("labels", batch)
|
||||||
|
self.assertNotIn("decoder_attention_mask", batch)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_seq2seq_batch_max_target_length(self):
|
||||||
|
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
||||||
|
tgt_text = [
|
||||||
|
"Summary of the text.",
|
||||||
|
"Another summary.",
|
||||||
|
]
|
||||||
|
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||||
|
batch = tokenizer.prepare_seq2seq_batch(
|
||||||
|
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
|
||||||
|
)
|
||||||
|
self.assertEqual(32, batch["labels"].shape[1])
|
||||||
|
|
||||||
|
# test None max_target_length
|
||||||
|
batch = tokenizer.prepare_seq2seq_batch(
|
||||||
|
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
|
||||||
|
)
|
||||||
|
self.assertEqual(32, batch["labels"].shape[1])
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_seq2seq_batch_not_longer_than_maxlen(self):
|
||||||
|
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||||
|
batch = tokenizer.prepare_seq2seq_batch(
|
||||||
|
["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt"
|
||||||
|
)
|
||||||
|
self.assertIsInstance(batch, BatchEncoding)
|
||||||
|
self.assertEqual(batch.input_ids.shape, (2, 1024))
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_special_tokens(self):
|
||||||
|
|
||||||
|
src_text = ["A long paragraph for summrization."]
|
||||||
|
tgt_text = [
|
||||||
|
"Summary of the text.",
|
||||||
|
]
|
||||||
|
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||||
|
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
|
||||||
|
input_ids = batch["input_ids"]
|
||||||
|
labels = batch["labels"]
|
||||||
|
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
|
||||||
|
self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
|
||||||
|
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
|
||||||
|
self.assertTrue((labels[:, -1] == tokenizer.eos_token_id).all().item())
|
||||||
@@ -1555,14 +1555,19 @@ class TokenizerTesterMixin:
|
|||||||
"vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
|
"vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
|
||||||
]
|
]
|
||||||
batch = tokenizer.prepare_seq2seq_batch(
|
batch = tokenizer.prepare_seq2seq_batch(
|
||||||
src_texts=src_text, tgt_texts=tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
|
src_texts=src_text,
|
||||||
|
tgt_texts=tgt_text,
|
||||||
|
max_length=3,
|
||||||
|
max_target_length=10,
|
||||||
|
return_tensors="pt",
|
||||||
|
src_lang="en_XX", # this should be ignored (for all but mbart) but not cause an error
|
||||||
)
|
)
|
||||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
self.assertEqual(batch.labels.shape[1], 10)
|
||||||
# max_target_length will default to max_length if not specified
|
# max_target_length will default to max_length if not specified
|
||||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
|
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
|
||||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
self.assertEqual(batch.labels.shape[1], 3)
|
||||||
|
|
||||||
batch_encoder_only = tokenizer.prepare_seq2seq_batch(
|
batch_encoder_only = tokenizer.prepare_seq2seq_batch(
|
||||||
src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"
|
src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
|
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer, is_torch_available
|
||||||
from transformers.testing_utils import require_torch
|
from transformers.testing_utils import require_torch
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE
|
from .test_tokenization_xlm_roberta import SAMPLE_VOCAB, SPIECE_UNDERLINE
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
|
|
||||||
EN_CODE = 250004
|
EN_CODE = 250004
|
||||||
RO_CODE = 250020
|
RO_CODE = 250020
|
||||||
|
|
||||||
@@ -123,35 +126,6 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
|
||||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
|
||||||
|
|
||||||
def test_enro_tokenizer_prepare_seq2seq_batch(self):
|
|
||||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
|
||||||
self.src_text,
|
|
||||||
tgt_texts=self.tgt_text,
|
|
||||||
max_length=len(self.expected_src_tokens),
|
|
||||||
)
|
|
||||||
self.assertIsInstance(batch, BatchEncoding)
|
|
||||||
|
|
||||||
self.assertEqual((2, 14), batch.input_ids.shape)
|
|
||||||
self.assertEqual((2, 14), batch.attention_mask.shape)
|
|
||||||
result = batch.input_ids.tolist()[0]
|
|
||||||
self.assertListEqual(self.expected_src_tokens, result)
|
|
||||||
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
|
|
||||||
# Test that special tokens are reset
|
|
||||||
self.assertEqual(self.tokenizer.prefix_tokens, [])
|
|
||||||
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
|
|
||||||
|
|
||||||
def test_max_target_length(self):
|
|
||||||
|
|
||||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
|
||||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
|
|
||||||
)
|
|
||||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
|
||||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
|
||||||
# max_target_length will default to max_length if not specified
|
|
||||||
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
|
|
||||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
|
||||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
|
||||||
|
|
||||||
def test_enro_tokenizer_batch_encode_plus(self):
|
def test_enro_tokenizer_batch_encode_plus(self):
|
||||||
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
||||||
self.assertListEqual(self.expected_src_tokens, ids)
|
self.assertListEqual(self.expected_src_tokens, ids)
|
||||||
@@ -169,7 +143,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
|||||||
assert isinstance(src_text[0], str)
|
assert isinstance(src_text[0], str)
|
||||||
desired_max_length = 10
|
desired_max_length = 10
|
||||||
ids = self.tokenizer.prepare_seq2seq_batch(
|
ids = self.tokenizer.prepare_seq2seq_batch(
|
||||||
src_text, return_tensors=None, max_length=desired_max_length
|
src_text,
|
||||||
|
return_tensors=None,
|
||||||
|
max_length=desired_max_length,
|
||||||
).input_ids[0]
|
).input_ids[0]
|
||||||
self.assertEqual(ids[-2], 2)
|
self.assertEqual(ids[-2], 2)
|
||||||
self.assertEqual(ids[-1], EN_CODE)
|
self.assertEqual(ids[-1], EN_CODE)
|
||||||
@@ -184,3 +160,53 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
|||||||
self.tokenizer.save_pretrained(tmpdirname)
|
self.tokenizer.save_pretrained(tmpdirname)
|
||||||
new_tok = MBartTokenizer.from_pretrained(tmpdirname)
|
new_tok = MBartTokenizer.from_pretrained(tmpdirname)
|
||||||
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
|
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
|
||||||
|
|
||||||
|
# prepare_seq2seq_batch tests below
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_batch_fairseq_parity(self):
|
||||||
|
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
|
||||||
|
self.src_text, tgt_texts=self.tgt_text, return_tensors="pt"
|
||||||
|
)
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||||
|
for k in batch:
|
||||||
|
batch[k] = batch[k].tolist()
|
||||||
|
# batch = {k: v.tolist() for k,v in batch.items()}
|
||||||
|
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||||
|
# batch.decoder_inputs_ids[0][0] ==
|
||||||
|
assert batch.input_ids[1][-2:] == [2, EN_CODE]
|
||||||
|
assert batch.decoder_input_ids[1][0] == RO_CODE
|
||||||
|
assert batch.decoder_input_ids[1][-1] == 2
|
||||||
|
assert batch.labels[1][-2:] == [2, RO_CODE]
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_enro_tokenizer_prepare_seq2seq_batch(self):
|
||||||
|
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||||
|
self.src_text,
|
||||||
|
tgt_texts=self.tgt_text,
|
||||||
|
max_length=len(self.expected_src_tokens),
|
||||||
|
)
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||||
|
self.assertIsInstance(batch, BatchEncoding)
|
||||||
|
|
||||||
|
self.assertEqual((2, 14), batch.input_ids.shape)
|
||||||
|
self.assertEqual((2, 14), batch.attention_mask.shape)
|
||||||
|
result = batch.input_ids.tolist()[0]
|
||||||
|
self.assertListEqual(self.expected_src_tokens, result)
|
||||||
|
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
|
||||||
|
# Test that special tokens are reset
|
||||||
|
self.assertEqual(self.tokenizer.prefix_tokens, [])
|
||||||
|
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
|
||||||
|
|
||||||
|
def test_seq2seq_max_target_length(self):
|
||||||
|
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||||
|
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
|
||||||
|
)
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||||
|
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||||
|
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||||
|
# max_target_length will default to max_length if not specified
|
||||||
|
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
|
||||||
|
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||||
|
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||||
|
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||||
|
|||||||
@@ -63,7 +63,6 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
|
batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
|
||||||
assert batch.input_ids.shape == (2, 1024)
|
assert batch.input_ids.shape == (2, 1024)
|
||||||
assert batch.attention_mask.shape == (2, 1024)
|
assert batch.attention_mask.shape == (2, 1024)
|
||||||
assert "decoder_input_ids" in batch # because tgt_texts was specified
|
assert "labels" in batch # because tgt_texts was specified
|
||||||
assert batch.decoder_input_ids.shape == (2, 5)
|
assert batch.labels.shape == (2, 5)
|
||||||
assert batch.decoder_attention_mask.shape == (2, 5)
|
assert len(batch) == 3 # input_ids, attention_mask, labels. Other things make by BartModel
|
||||||
assert len(batch) == 4 # no extra keys
|
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def get_tokenizer(self, **kwargs):
|
def get_tokenizer(self, **kwargs):
|
||||||
kwargs.update(self.special_tokens_map)
|
kwargs.update(self.special_tokens_map)
|
||||||
return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
def get_rust_tokenizer(self, **kwargs):
|
def get_rust_tokenizer(self, **kwargs):
|
||||||
kwargs.update(self.special_tokens_map)
|
kwargs.update(self.special_tokens_map)
|
||||||
@@ -78,7 +78,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
return input_text, output_text
|
return input_text, output_text
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
def test_full_tokenizer(self):
|
||||||
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
||||||
text = "lower newer"
|
text = "lower newer"
|
||||||
bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"]
|
bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"]
|
||||||
tokens = tokenizer.tokenize(text) # , add_prefix_space=True)
|
tokens = tokenizer.tokenize(text) # , add_prefix_space=True)
|
||||||
@@ -99,7 +99,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_sequence_builders(self):
|
def test_sequence_builders(self):
|
||||||
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
|
tokenizer = self.tokenizer_class.from_pretrained("roberta-base")
|
||||||
|
|
||||||
text = tokenizer.encode("sequence builders", add_special_tokens=False)
|
text = tokenizer.encode("sequence builders", add_special_tokens=False)
|
||||||
text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
|
text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
|
||||||
@@ -137,7 +137,7 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0]
|
first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0]
|
||||||
self.assertNotEqual(first_char, space_encoding)
|
self.assertNotEqual(first_char, space_encoding)
|
||||||
|
|
||||||
# Testing spaces after special tokenss
|
# Testing spaces after special tokens
|
||||||
mask = "<mask>"
|
mask = "<mask>"
|
||||||
tokenizer.add_special_tokens(
|
tokenizer.add_special_tokens(
|
||||||
{"mask_token": AddedToken(mask, lstrip=True, rstrip=False)}
|
{"mask_token": AddedToken(mask, lstrip=True, rstrip=False)}
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_max_target_length(self):
|
def test_max_target_length(self):
|
||||||
tokenizer = self.t5_base_tokenizer
|
tokenizer = self.t5_base_tokenizer
|
||||||
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
src_text = ["A short paragraph for summrization.", "Another short paragraph for summrization."]
|
||||||
tgt_text = [
|
tgt_text = [
|
||||||
"Summary of the text.",
|
"Summary of the text.",
|
||||||
"Another summary.",
|
"Another summary.",
|
||||||
@@ -161,14 +161,14 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
batch = tokenizer.prepare_seq2seq_batch(
|
batch = tokenizer.prepare_seq2seq_batch(
|
||||||
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
|
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
|
||||||
)
|
)
|
||||||
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
|
self.assertEqual(32, batch["labels"].shape[1])
|
||||||
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
|
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
|
||||||
|
|
||||||
# test None max_target_length
|
# test None max_target_length
|
||||||
batch = tokenizer.prepare_seq2seq_batch(
|
batch = tokenizer.prepare_seq2seq_batch(
|
||||||
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
|
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
|
||||||
)
|
)
|
||||||
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
|
self.assertEqual(32, batch["labels"].shape[1])
|
||||||
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
|
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
|
||||||
|
|
||||||
def test_outputs_not_longer_than_maxlen(self):
|
def test_outputs_not_longer_than_maxlen(self):
|
||||||
@@ -190,7 +190,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
|
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
|
||||||
|
|
||||||
src_ids = list(batch.input_ids.numpy()[0])
|
src_ids = list(batch.input_ids.numpy()[0])
|
||||||
tgt_ids = list(batch.decoder_input_ids.numpy()[0])
|
tgt_ids = list(batch.labels.numpy()[0])
|
||||||
|
|
||||||
self.assertEqual(expected_src_tokens, src_ids)
|
self.assertEqual(expected_src_tokens, src_ids)
|
||||||
self.assertEqual(expected_tgt_tokens, tgt_ids)
|
self.assertEqual(expected_tgt_tokens, tgt_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user