Seq2seq trainer (#9241)
* Add label smoothing in Trainer * Add options for scheduler and Adafactor in Trainer * Put Seq2SeqTrainer in the main lib * Apply suggestions from code review Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Address review comments and adapt scripts * Documentation * Move test not using script to tests folder Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -43,6 +43,10 @@ Schedules
|
|||||||
Learning Rate Schedules (Pytorch)
|
Learning Rate Schedules (Pytorch)
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. autoclass:: transformers.SchedulerType
|
||||||
|
|
||||||
|
.. autofunction:: transformers.get_scheduler
|
||||||
|
|
||||||
.. autofunction:: transformers.get_constant_schedule
|
.. autofunction:: transformers.get_constant_schedule
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,13 @@ Trainer
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
Seq2SeqTrainer
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.Seq2SeqTrainer
|
||||||
|
:members: evaluate, predict
|
||||||
|
|
||||||
|
|
||||||
TFTrainer
|
TFTrainer
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -77,6 +84,13 @@ TrainingArguments
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
Seq2SeqTrainingArguments
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.Seq2SeqTrainingArguments
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
TFTrainingArguments
|
TFTrainingArguments
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -20,9 +20,16 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from seq2seq_trainer import Seq2SeqTrainer
|
from transformers import (
|
||||||
from seq2seq_training_args import Seq2SeqTrainingArguments
|
AutoConfig,
|
||||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
|
AutoModelForSeq2SeqLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
HfArgumentParser,
|
||||||
|
MBartTokenizer,
|
||||||
|
Seq2SeqTrainer,
|
||||||
|
Seq2SeqTrainingArguments,
|
||||||
|
set_seed,
|
||||||
|
)
|
||||||
from transformers.trainer_utils import EvaluationStrategy, is_main_process
|
from transformers.trainer_utils import EvaluationStrategy, is_main_process
|
||||||
from transformers.training_args import ParallelMode
|
from transformers.training_args import ParallelMode
|
||||||
from utils import (
|
from utils import (
|
||||||
@@ -86,14 +93,14 @@ class DataTrainingArguments:
|
|||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_target_length: Optional[int] = field(
|
max_length: Optional[int] = field(
|
||||||
default=128,
|
default=128,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
"than this will be truncated, sequences shorter will be padded."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
val_max_target_length: Optional[int] = field(
|
eval_max_length: Optional[int] = field(
|
||||||
default=142,
|
default=142,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||||
@@ -101,13 +108,6 @@ class DataTrainingArguments:
|
|||||||
" This argument is also used to override the ``max_length`` param of ``model.generate``, which is used during ``evaluate`` and ``predict``"
|
" This argument is also used to override the ``max_length`` param of ``model.generate``, which is used during ``evaluate`` and ``predict``"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
test_max_target_length: Optional[int] = field(
|
|
||||||
default=142,
|
|
||||||
metadata={
|
|
||||||
"help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
|
|
||||||
"than this will be truncated, sequences shorter will be padded."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
|
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
|
||||||
n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."})
|
n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."})
|
||||||
n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."})
|
n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."})
|
||||||
@@ -233,7 +233,7 @@ def main():
|
|||||||
type_path="train",
|
type_path="train",
|
||||||
data_dir=data_args.data_dir,
|
data_dir=data_args.data_dir,
|
||||||
n_obs=data_args.n_train,
|
n_obs=data_args.n_train,
|
||||||
max_target_length=data_args.max_target_length,
|
max_target_length=data_args.max_length,
|
||||||
max_source_length=data_args.max_source_length,
|
max_source_length=data_args.max_source_length,
|
||||||
prefix=model.config.prefix or "",
|
prefix=model.config.prefix or "",
|
||||||
)
|
)
|
||||||
@@ -246,7 +246,7 @@ def main():
|
|||||||
type_path="val",
|
type_path="val",
|
||||||
data_dir=data_args.data_dir,
|
data_dir=data_args.data_dir,
|
||||||
n_obs=data_args.n_val,
|
n_obs=data_args.n_val,
|
||||||
max_target_length=data_args.val_max_target_length,
|
max_target_length=data_args.eval_max_length,
|
||||||
max_source_length=data_args.max_source_length,
|
max_source_length=data_args.max_source_length,
|
||||||
prefix=model.config.prefix or "",
|
prefix=model.config.prefix or "",
|
||||||
)
|
)
|
||||||
@@ -259,7 +259,7 @@ def main():
|
|||||||
type_path="test",
|
type_path="test",
|
||||||
data_dir=data_args.data_dir,
|
data_dir=data_args.data_dir,
|
||||||
n_obs=data_args.n_test,
|
n_obs=data_args.n_test,
|
||||||
max_target_length=data_args.test_max_target_length,
|
max_target_length=data_args.eval_max_length,
|
||||||
max_source_length=data_args.max_source_length,
|
max_source_length=data_args.max_source_length,
|
||||||
prefix=model.config.prefix or "",
|
prefix=model.config.prefix or "",
|
||||||
)
|
)
|
||||||
@@ -273,13 +273,12 @@ def main():
|
|||||||
)
|
)
|
||||||
trainer = Seq2SeqTrainer(
|
trainer = Seq2SeqTrainer(
|
||||||
model=model,
|
model=model,
|
||||||
config=config,
|
|
||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||||
compute_metrics=compute_metrics_fn,
|
compute_metrics=compute_metrics_fn,
|
||||||
data_args=data_args,
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
all_metrics = {}
|
all_metrics = {}
|
||||||
@@ -310,7 +309,9 @@ def main():
|
|||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
|
|
||||||
metrics = trainer.evaluate(metric_key_prefix="val")
|
metrics = trainer.evaluate(
|
||||||
|
metric_key_prefix="val", max_length=data_args.eval_max_length, num_beams=data_args.eval_beams
|
||||||
|
)
|
||||||
metrics["val_n_objs"] = data_args.n_val
|
metrics["val_n_objs"] = data_args.n_val
|
||||||
metrics["val_loss"] = round(metrics["val_loss"], 4)
|
metrics["val_loss"] = round(metrics["val_loss"], 4)
|
||||||
|
|
||||||
@@ -322,7 +323,12 @@ def main():
|
|||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logger.info("*** Predict ***")
|
logger.info("*** Predict ***")
|
||||||
|
|
||||||
test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
|
test_output = trainer.predict(
|
||||||
|
test_dataset=test_dataset,
|
||||||
|
metric_key_prefix="test",
|
||||||
|
max_length=data_args.eval_max_length,
|
||||||
|
num_beams=data_args.eval_beams,
|
||||||
|
)
|
||||||
metrics = test_output.metrics
|
metrics = test_output.metrics
|
||||||
metrics["test_n_objs"] = data_args.n_test
|
metrics["test_n_objs"] = data_args.n_test
|
||||||
|
|
||||||
|
|||||||
@@ -17,8 +17,7 @@ import sys
|
|||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from transformers import BertTokenizer, EncoderDecoderModel
|
from transformers.file_utils import is_apex_available
|
||||||
from transformers.file_utils import is_apex_available, is_datasets_available
|
|
||||||
from transformers.integrations import is_fairscale_available
|
from transformers.integrations import is_fairscale_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
@@ -31,8 +30,7 @@ from transformers.testing_utils import (
|
|||||||
from transformers.trainer_callback import TrainerState
|
from transformers.trainer_callback import TrainerState
|
||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import set_seed
|
||||||
|
|
||||||
from .finetune_trainer import Seq2SeqTrainingArguments, main
|
from .finetune_trainer import main
|
||||||
from .seq2seq_trainer import Seq2SeqTrainer
|
|
||||||
|
|
||||||
|
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
@@ -120,119 +118,6 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
assert "test_generations.txt" in contents
|
assert "test_generations.txt" in contents
|
||||||
assert "test_results.json" in contents
|
assert "test_results.json" in contents
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_finetune_bert2bert(self):
|
|
||||||
if not is_datasets_available():
|
|
||||||
return
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
|
|
||||||
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
|
|
||||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
|
||||||
|
|
||||||
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
|
|
||||||
bert2bert.config.eos_token_id = tokenizer.sep_token_id
|
|
||||||
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
|
|
||||||
bert2bert.config.max_length = 128
|
|
||||||
|
|
||||||
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
|
|
||||||
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
|
|
||||||
|
|
||||||
train_dataset = train_dataset.select(range(32))
|
|
||||||
val_dataset = val_dataset.select(range(16))
|
|
||||||
|
|
||||||
rouge = datasets.load_metric("rouge")
|
|
||||||
|
|
||||||
batch_size = 4
|
|
||||||
|
|
||||||
def _map_to_encoder_decoder_inputs(batch):
|
|
||||||
# Tokenizer will automatically set [BOS] <text> [EOS]
|
|
||||||
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
|
|
||||||
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
|
|
||||||
batch["input_ids"] = inputs.input_ids
|
|
||||||
batch["attention_mask"] = inputs.attention_mask
|
|
||||||
|
|
||||||
batch["decoder_input_ids"] = outputs.input_ids
|
|
||||||
batch["labels"] = outputs.input_ids.copy()
|
|
||||||
batch["labels"] = [
|
|
||||||
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
|
|
||||||
]
|
|
||||||
batch["decoder_attention_mask"] = outputs.attention_mask
|
|
||||||
|
|
||||||
assert all([len(x) == 512 for x in inputs.input_ids])
|
|
||||||
assert all([len(x) == 128 for x in outputs.input_ids])
|
|
||||||
|
|
||||||
return batch
|
|
||||||
|
|
||||||
def _compute_metrics(pred):
|
|
||||||
labels_ids = pred.label_ids
|
|
||||||
pred_ids = pred.predictions
|
|
||||||
|
|
||||||
# all unnecessary tokens are removed
|
|
||||||
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
|
||||||
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
|
|
||||||
|
|
||||||
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
|
|
||||||
"rouge2"
|
|
||||||
].mid
|
|
||||||
|
|
||||||
return {
|
|
||||||
"rouge2_precision": round(rouge_output.precision, 4),
|
|
||||||
"rouge2_recall": round(rouge_output.recall, 4),
|
|
||||||
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
|
|
||||||
}
|
|
||||||
|
|
||||||
# map train dataset
|
|
||||||
train_dataset = train_dataset.map(
|
|
||||||
_map_to_encoder_decoder_inputs,
|
|
||||||
batched=True,
|
|
||||||
batch_size=batch_size,
|
|
||||||
remove_columns=["article", "highlights"],
|
|
||||||
)
|
|
||||||
train_dataset.set_format(
|
|
||||||
type="torch",
|
|
||||||
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# same for validation dataset
|
|
||||||
val_dataset = val_dataset.map(
|
|
||||||
_map_to_encoder_decoder_inputs,
|
|
||||||
batched=True,
|
|
||||||
batch_size=batch_size,
|
|
||||||
remove_columns=["article", "highlights"],
|
|
||||||
)
|
|
||||||
val_dataset.set_format(
|
|
||||||
type="torch",
|
|
||||||
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
|
||||||
)
|
|
||||||
|
|
||||||
output_dir = self.get_auto_remove_tmp_dir()
|
|
||||||
|
|
||||||
training_args = Seq2SeqTrainingArguments(
|
|
||||||
output_dir=output_dir,
|
|
||||||
per_device_train_batch_size=batch_size,
|
|
||||||
per_device_eval_batch_size=batch_size,
|
|
||||||
predict_with_generate=True,
|
|
||||||
evaluation_strategy="steps",
|
|
||||||
do_train=True,
|
|
||||||
do_eval=True,
|
|
||||||
warmup_steps=0,
|
|
||||||
eval_steps=2,
|
|
||||||
logging_steps=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
# instantiate trainer
|
|
||||||
trainer = Seq2SeqTrainer(
|
|
||||||
model=bert2bert,
|
|
||||||
args=training_args,
|
|
||||||
compute_metrics=_compute_metrics,
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
eval_dataset=val_dataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
# start training
|
|
||||||
trainer.train()
|
|
||||||
|
|
||||||
def run_trainer(
|
def run_trainer(
|
||||||
self,
|
self,
|
||||||
eval_steps: int,
|
eval_steps: int,
|
||||||
@@ -252,8 +137,8 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
--n_train 8
|
--n_train 8
|
||||||
--n_val 8
|
--n_val 8
|
||||||
--max_source_length {max_len}
|
--max_source_length {max_len}
|
||||||
--max_target_length {max_len}
|
--max_length {max_len}
|
||||||
--val_max_target_length {max_len}
|
--eval_max_length {max_len}
|
||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
--do_predict
|
--do_predict
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ python finetune_trainer.py \
|
|||||||
--freeze_encoder --freeze_embeds \
|
--freeze_encoder --freeze_embeds \
|
||||||
--num_train_epochs=6 \
|
--num_train_epochs=6 \
|
||||||
--save_steps 3000 --eval_steps 3000 \
|
--save_steps 3000 --eval_steps 3000 \
|
||||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
|
||||||
--do_train --do_eval --do_predict \
|
--do_train --do_eval --do_predict \
|
||||||
--evaluation_strategy steps \
|
--evaluation_strategy steps \
|
||||||
--predict_with_generate --logging_first_step \
|
--predict_with_generate --logging_first_step \
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ python xla_spawn.py --num_cores $TPU_NUM_CORES \
|
|||||||
--num_train_epochs=6 \
|
--num_train_epochs=6 \
|
||||||
--save_steps 500 --eval_steps 500 \
|
--save_steps 500 --eval_steps 500 \
|
||||||
--logging_first_step --logging_steps 200 \
|
--logging_first_step --logging_steps 200 \
|
||||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
|
||||||
--do_train --do_eval \
|
--do_train --do_eval \
|
||||||
--evaluation_strategy steps \
|
--evaluation_strategy steps \
|
||||||
--prediction_loss_only \
|
--prediction_loss_only \
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ python finetune_trainer.py \
|
|||||||
--num_train_epochs=2 \
|
--num_train_epochs=2 \
|
||||||
--save_steps 3000 --eval_steps 3000 \
|
--save_steps 3000 --eval_steps 3000 \
|
||||||
--logging_first_step \
|
--logging_first_step \
|
||||||
--max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
|
--max_length 56 --eval_max_length $MAX_TGT_LEN \
|
||||||
--do_train --do_eval --do_predict \
|
--do_train --do_eval --do_predict \
|
||||||
--evaluation_strategy steps \
|
--evaluation_strategy steps \
|
||||||
--predict_with_generate --sortish_sampler \
|
--predict_with_generate --sortish_sampler \
|
||||||
|
|||||||
@@ -24,8 +24,7 @@ python finetune_trainer.py \
|
|||||||
--src_lang en_XX --tgt_lang ro_RO \
|
--src_lang en_XX --tgt_lang ro_RO \
|
||||||
--freeze_embeds \
|
--freeze_embeds \
|
||||||
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
|
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
|
||||||
--max_source_length 128 --max_target_length 128 \
|
--max_source_length 128 --max_length 128 --eval_max_length 128 \
|
||||||
--val_max_target_length 128 --test_max_target_length 128 \
|
|
||||||
--sortish_sampler \
|
--sortish_sampler \
|
||||||
--num_train_epochs 6 \
|
--num_train_epochs 6 \
|
||||||
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
|
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
|
||||||
|
|||||||
@@ -330,7 +330,7 @@ class Seq2SeqDataCollator:
|
|||||||
[x["src_texts"] for x in batch],
|
[x["src_texts"] for x in batch],
|
||||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||||
max_length=self.data_args.max_source_length,
|
max_length=self.data_args.max_source_length,
|
||||||
max_target_length=self.data_args.max_target_length,
|
max_target_length=self.data_args.max_length,
|
||||||
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
**self.dataset_kwargs,
|
**self.dataset_kwargs,
|
||||||
|
|||||||
@@ -287,8 +287,9 @@ from .trainer_callback import (
|
|||||||
TrainerControl,
|
TrainerControl,
|
||||||
TrainerState,
|
TrainerState,
|
||||||
)
|
)
|
||||||
from .trainer_utils import EvalPrediction, EvaluationStrategy, set_seed
|
from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed
|
||||||
from .training_args import TrainingArguments
|
from .training_args import TrainingArguments
|
||||||
|
from .training_args_seq2seq import Seq2SeqTrainingArguments
|
||||||
from .training_args_tf import TFTrainingArguments
|
from .training_args_tf import TFTrainingArguments
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
@@ -682,11 +683,13 @@ if is_torch_available():
|
|||||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
get_polynomial_decay_schedule_with_warmup,
|
get_polynomial_decay_schedule_with_warmup,
|
||||||
|
get_scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trainer
|
# Trainer
|
||||||
from .trainer import Trainer
|
from .trainer import Trainer
|
||||||
from .trainer_pt_utils import torch_distributed_zero_first
|
from .trainer_pt_utils import torch_distributed_zero_first
|
||||||
|
from .trainer_seq2seq import Seq2SeqTrainer
|
||||||
else:
|
else:
|
||||||
from .utils.dummy_pt_objects import *
|
from .utils.dummy_pt_objects import *
|
||||||
|
|
||||||
|
|||||||
@@ -15,12 +15,13 @@
|
|||||||
"""PyTorch optimization for BERT model."""
|
"""PyTorch optimization for BERT model."""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Callable, Iterable, Tuple
|
from typing import Callable, Iterable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
|
||||||
|
from .trainer_utils import SchedulerType
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -215,6 +216,56 @@ def get_polynomial_decay_schedule_with_warmup(
|
|||||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||||
|
|
||||||
|
|
||||||
|
TYPE_TO_SCHEDULER_FUNCTION = {
|
||||||
|
SchedulerType.LINEAR: get_linear_schedule_with_warmup,
|
||||||
|
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
|
||||||
|
SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||||
|
SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
|
||||||
|
SchedulerType.CONSTANT: get_constant_schedule,
|
||||||
|
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler(
|
||||||
|
name: Union[str, SchedulerType],
|
||||||
|
optimizer: Optimizer,
|
||||||
|
num_warmup_steps: Optional[int] = None,
|
||||||
|
num_training_steps: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Unified API to get any scheduler from its name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (:obj:`str` or `:obj:`SchedulerType`):
|
||||||
|
The name of the scheduler to use.
|
||||||
|
optimizer (:obj:`torch.optim.Optimizer`):
|
||||||
|
The optimizer that will be used during training.
|
||||||
|
num_warmup_steps (:obj:`int`, `optional`):
|
||||||
|
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
||||||
|
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||||
|
num_training_steps (:obj:`int`, `optional`):
|
||||||
|
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
||||||
|
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||||
|
"""
|
||||||
|
name = SchedulerType(name)
|
||||||
|
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||||
|
if name == SchedulerType.CONSTANT:
|
||||||
|
return schedule_func(optimizer)
|
||||||
|
|
||||||
|
# All other schedulers require `num_warmup_steps`
|
||||||
|
if num_warmup_steps is None:
|
||||||
|
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||||
|
|
||||||
|
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||||
|
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
||||||
|
|
||||||
|
# All other schedulers require `num_training_steps`
|
||||||
|
if num_training_steps is None:
|
||||||
|
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
||||||
|
|
||||||
|
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
||||||
|
|
||||||
|
|
||||||
class AdamW(Optimizer):
|
class AdamW(Optimizer):
|
||||||
"""
|
"""
|
||||||
Implements Adam algorithm with weight decay fix as introduced in `Decoupled Weight Decay Regularization
|
Implements Adam algorithm with weight decay fix as introduced in `Decoupled Weight Decay Regularization
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
|
|||||||
from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||||
from .optimization import AdamW, get_linear_schedule_with_warmup
|
from .optimization import Adafactor, AdamW, get_scheduler
|
||||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
from .trainer_callback import (
|
from .trainer_callback import (
|
||||||
CallbackHandler,
|
CallbackHandler,
|
||||||
@@ -70,6 +70,7 @@ from .trainer_callback import (
|
|||||||
)
|
)
|
||||||
from .trainer_pt_utils import (
|
from .trainer_pt_utils import (
|
||||||
DistributedTensorGatherer,
|
DistributedTensorGatherer,
|
||||||
|
LabelSmoother,
|
||||||
SequentialDistributedSampler,
|
SequentialDistributedSampler,
|
||||||
distributed_broadcast_scalars,
|
distributed_broadcast_scalars,
|
||||||
distributed_concat,
|
distributed_concat,
|
||||||
@@ -320,6 +321,12 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
self.use_apex = True
|
self.use_apex = True
|
||||||
|
|
||||||
|
# Label smoothing
|
||||||
|
if self.args.label_smoothing_factor != 0:
|
||||||
|
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
|
||||||
|
else:
|
||||||
|
self.label_smoother = None
|
||||||
|
|
||||||
self.state = TrainerState()
|
self.state = TrainerState()
|
||||||
self.control = TrainerControl()
|
self.control = TrainerControl()
|
||||||
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
|
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
|
||||||
@@ -507,24 +514,32 @@ class Trainer:
|
|||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
optimizer_cls = Adafactor if self.args.adafactor else AdamW
|
||||||
|
if self.args.adafactor:
|
||||||
|
optimizer_cls = Adafactor
|
||||||
|
optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
|
||||||
|
else:
|
||||||
|
optimizer_cls = AdamW
|
||||||
|
optimizer_kwargs = {
|
||||||
|
"betas": (self.args.adam_beta1, self.args.adam_beta2),
|
||||||
|
"eps": self.args.adam_epsilon,
|
||||||
|
}
|
||||||
|
optimizer_kwargs["lr"] = self.args.learning_rate
|
||||||
if self.sharded_dpp:
|
if self.sharded_dpp:
|
||||||
self.optimizer = OSS(
|
self.optimizer = OSS(
|
||||||
params=optimizer_grouped_parameters,
|
params=optimizer_grouped_parameters,
|
||||||
optim=AdamW,
|
optim=optimizer_cls,
|
||||||
lr=self.args.learning_rate,
|
**optimizer_kwargs,
|
||||||
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
|
||||||
eps=self.args.adam_epsilon,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.optimizer = AdamW(
|
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
optimizer_grouped_parameters,
|
|
||||||
lr=self.args.learning_rate,
|
|
||||||
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
|
||||||
eps=self.args.adam_epsilon,
|
|
||||||
)
|
|
||||||
if self.lr_scheduler is None:
|
if self.lr_scheduler is None:
|
||||||
self.lr_scheduler = get_linear_schedule_with_warmup(
|
self.lr_scheduler = get_scheduler(
|
||||||
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
self.args.lr_scheduler_type,
|
||||||
|
self.optimizer,
|
||||||
|
num_warmup_steps=self.args.warmup_steps,
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
def num_examples(self, dataloader: DataLoader) -> int:
|
def num_examples(self, dataloader: DataLoader) -> int:
|
||||||
@@ -1168,8 +1183,12 @@ class Trainer:
|
|||||||
# TODO: this needs to be fixed and made cleaner later.
|
# TODO: this needs to be fixed and made cleaner later.
|
||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
self._past = outputs[self.args.past_index]
|
self._past = outputs[self.args.past_index]
|
||||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
|
||||||
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
if self.label_smoother is not None and "labels" in inputs:
|
||||||
|
return self.label_smoother(outputs, inputs["labels"])
|
||||||
|
else:
|
||||||
|
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||||
|
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||||
|
|
||||||
def is_local_process_zero(self) -> bool:
|
def is_local_process_zero(self) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -1556,11 +1575,13 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
if has_labels:
|
if has_labels:
|
||||||
|
if self.label_smoother is not None and "labels" in inputs:
|
||||||
|
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
|
||||||
|
else:
|
||||||
|
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
|
||||||
if isinstance(outputs, dict):
|
if isinstance(outputs, dict):
|
||||||
loss = outputs["loss"].mean().detach()
|
|
||||||
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
|
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
|
||||||
else:
|
else:
|
||||||
loss = outputs[0].mean().detach()
|
|
||||||
logits = outputs[1:]
|
logits = outputs[1:]
|
||||||
else:
|
else:
|
||||||
loss = None
|
loss = None
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ Torch utilities for the Trainer class.
|
|||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -360,3 +361,32 @@ class DistributedTensorGatherer:
|
|||||||
if self._offsets[0] != self.process_length:
|
if self._offsets[0] != self.process_length:
|
||||||
logger.warn("Not all data has been set. Are you sure you passed all values?")
|
logger.warn("Not all data has been set. Are you sure you passed all values?")
|
||||||
return nested_truncate(self._storage, self.num_samples)
|
return nested_truncate(self._storage, self.num_samples)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LabelSmoother:
|
||||||
|
"""
|
||||||
|
Adds label-smoothing on a pre-computed output from a Transformers model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epsilon (:obj:`float`, `optional`, defaults to 0.1):
|
||||||
|
The label smoothing factor.
|
||||||
|
ignore_index (:obj:`int`, `optional`, defaults to -100):
|
||||||
|
The index in the labels to ignore when computing the loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
epsilon: float = 0.1
|
||||||
|
ignore_index: int = -100
|
||||||
|
|
||||||
|
def __call__(self, model_output, labels):
|
||||||
|
model_loss = model_output["loss"] if isinstance(model_output, dict) else model_output[0]
|
||||||
|
logits = model_output["logits"] if isinstance(model_output, dict) else model_output[1]
|
||||||
|
log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
# Look at the ignored index and mask the corresponding log_probs.
|
||||||
|
padding_mask = labels.unsqueeze(-1).eq(self.ignore_index)
|
||||||
|
log_probs.masked_fill_(padding_mask, 0.0)
|
||||||
|
|
||||||
|
# Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
|
||||||
|
smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
|
||||||
|
return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
|
||||||
|
|||||||
231
src/transformers/trainer_seq2seq.py
Normal file
231
src/transformers/trainer_seq2seq.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DistributedSampler, RandomSampler
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
|
||||||
|
from .file_utils import is_torch_tpu_available
|
||||||
|
from .trainer import Trainer
|
||||||
|
from .trainer_pt_utils import get_tpu_sampler
|
||||||
|
from .trainer_utils import PredictionOutput
|
||||||
|
from .training_args import ParallelMode
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
if version.parse(torch.__version__) >= version.parse("1.6"):
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqTrainer(Trainer):
|
||||||
|
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||||
|
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||||
|
return None
|
||||||
|
elif is_torch_tpu_available():
|
||||||
|
return get_tpu_sampler(self.train_dataset)
|
||||||
|
else:
|
||||||
|
if self.args.sortish_sampler:
|
||||||
|
self.train_dataset.make_sortish_sampler(
|
||||||
|
self.args.per_device_train_batch_size,
|
||||||
|
distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
RandomSampler(self.train_dataset)
|
||||||
|
if self.args.local_rank == -1
|
||||||
|
else DistributedSampler(self.train_dataset)
|
||||||
|
)
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
self,
|
||||||
|
eval_dataset: Optional[Dataset] = None,
|
||||||
|
ignore_keys: Optional[List[str]] = None,
|
||||||
|
metric_key_prefix: str = "eval",
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
num_beams: Optional[int] = None,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Run evaluation and returns metrics.
|
||||||
|
|
||||||
|
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
|
||||||
|
(pass it to the init :obj:`compute_metrics` argument).
|
||||||
|
|
||||||
|
You can also subclass and override this method to inject custom behavior.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eval_dataset (:obj:`Dataset`, `optional`):
|
||||||
|
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
|
||||||
|
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
|
||||||
|
:obj:`__len__` method.
|
||||||
|
ignore_keys (:obj:`List[str]`, `optional`):
|
||||||
|
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||||
|
gathering predictions.
|
||||||
|
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
|
||||||
|
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
|
||||||
|
"eval_bleu" if the prefix is ``"eval"`` (default)
|
||||||
|
max_length (:obj:`int`, `optional`):
|
||||||
|
The maximum target length to use when predicting with the generate method.
|
||||||
|
num_beams (:obj:`int`, `optional`):
|
||||||
|
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
|
||||||
|
beam search.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
||||||
|
dictionary also contains the epoch number which comes from the training state.
|
||||||
|
"""
|
||||||
|
self._max_length = max_length
|
||||||
|
self._num_beams = num_beams
|
||||||
|
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
test_dataset: Dataset,
|
||||||
|
ignore_keys: Optional[List[str]] = None,
|
||||||
|
metric_key_prefix: str = "eval",
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
num_beams: Optional[int] = None,
|
||||||
|
) -> PredictionOutput:
|
||||||
|
"""
|
||||||
|
Run prediction and returns predictions and potential metrics.
|
||||||
|
|
||||||
|
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
|
||||||
|
will also return metrics, like in :obj:`evaluate()`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
test_dataset (:obj:`Dataset`):
|
||||||
|
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
||||||
|
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
|
||||||
|
ignore_keys (:obj:`List[str]`, `optional`):
|
||||||
|
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||||
|
gathering predictions.
|
||||||
|
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
|
||||||
|
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
|
||||||
|
"eval_bleu" if the prefix is ``"eval"`` (default)
|
||||||
|
max_length (:obj:`int`, `optional`):
|
||||||
|
The maximum target length to use when predicting with the generate method.
|
||||||
|
num_beams (:obj:`int`, `optional`):
|
||||||
|
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
|
||||||
|
beam search.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
|
||||||
|
padding in a token classification task) the predictions will be padded (on the right) to allow for
|
||||||
|
concatenation into one array. The padding index is -100.
|
||||||
|
|
||||||
|
Returns: `NamedTuple` A namedtuple with the following keys:
|
||||||
|
|
||||||
|
- predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
|
||||||
|
- label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
|
||||||
|
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
|
||||||
|
contained labels).
|
||||||
|
"""
|
||||||
|
self._max_length = max_length
|
||||||
|
self._num_beams = num_beams
|
||||||
|
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||||
|
|
||||||
|
def prediction_step(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||||
|
prediction_loss_only: bool,
|
||||||
|
ignore_keys: Optional[List[str]] = None,
|
||||||
|
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
||||||
|
|
||||||
|
Subclass and override to inject custom behavior.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (:obj:`nn.Module`):
|
||||||
|
The model to evaluate.
|
||||||
|
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
||||||
|
The inputs and targets of the model.
|
||||||
|
|
||||||
|
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||||
|
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||||
|
prediction_loss_only (:obj:`bool`):
|
||||||
|
Whether or not to return the loss only.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
|
||||||
|
labels (each being optional).
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not self.args.predict_with_generate or prediction_loss_only:
|
||||||
|
return super()(self, model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
|
has_labels = "labels" in inputs
|
||||||
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
|
gen_kwargs = {
|
||||||
|
"max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
|
||||||
|
"num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
|
||||||
|
}
|
||||||
|
|
||||||
|
generated_tokens = self.model.generate(
|
||||||
|
inputs["input_ids"],
|
||||||
|
attention_mask=inputs["attention_mask"],
|
||||||
|
**gen_kwargs,
|
||||||
|
)
|
||||||
|
# in case the batch is shorter than max length, the output should be padded
|
||||||
|
if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
|
||||||
|
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
if self.use_amp:
|
||||||
|
with autocast():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
else:
|
||||||
|
outputs = model(**inputs)
|
||||||
|
if has_labels:
|
||||||
|
if self.label_smoother is not None:
|
||||||
|
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
|
||||||
|
else:
|
||||||
|
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
|
||||||
|
else:
|
||||||
|
loss = None
|
||||||
|
|
||||||
|
if self.args.prediction_loss_only:
|
||||||
|
return (loss, None, None)
|
||||||
|
|
||||||
|
labels = inputs["labels"]
|
||||||
|
if labels.shape[-1] < gen_kwargs["max_length"]:
|
||||||
|
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
|
||||||
|
|
||||||
|
return (loss, generated_tokens, labels)
|
||||||
|
|
||||||
|
def _pad_tensors_to_max_len(self, tensor, max_length):
|
||||||
|
if self.tokenizer is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Tensor need to be padded to `max_length={max_length}` but no tokenzier was passed when creating "
|
||||||
|
"this `Trainer`. Make sure to create your `Trainer` with the appropriate tokenizer."
|
||||||
|
)
|
||||||
|
# If PAD token is not defined at least EOS token has to be defined
|
||||||
|
pad_token_id = (
|
||||||
|
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
padded_tensor = pad_token_id * torch.ones(
|
||||||
|
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
|
||||||
|
)
|
||||||
|
padded_tensor[:, : tensor.shape[-1]] = tensor
|
||||||
|
return padded_tensor
|
||||||
@@ -201,3 +201,12 @@ def speed_metrics(split, start_time, num_samples=None):
|
|||||||
samples_per_second = 1 / (runtime / num_samples)
|
samples_per_second = 1 / (runtime / num_samples)
|
||||||
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
|
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerType(ExplicitEnum):
|
||||||
|
LINEAR = "linear"
|
||||||
|
COSINE = "cosine"
|
||||||
|
COSINE_WITH_RESTARTS = "cosine_with_restarts"
|
||||||
|
POLYNOMIAL = "polynomial"
|
||||||
|
CONSTANT = "constant"
|
||||||
|
CONSTANT_WITH_WARMUP = "constant_with_warmup"
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from enum import Enum
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
||||||
from .trainer_utils import EvaluationStrategy
|
from .trainer_utils import EvaluationStrategy, SchedulerType
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -121,6 +121,9 @@ class TrainingArguments:
|
|||||||
max_steps (:obj:`int`, `optional`, defaults to -1):
|
max_steps (:obj:`int`, `optional`, defaults to -1):
|
||||||
If set to a positive number, the total number of training steps to perform. Overrides
|
If set to a positive number, the total number of training steps to perform. Overrides
|
||||||
:obj:`num_train_epochs`.
|
:obj:`num_train_epochs`.
|
||||||
|
lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`):
|
||||||
|
The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible
|
||||||
|
values.
|
||||||
warmup_steps (:obj:`int`, `optional`, defaults to 0):
|
warmup_steps (:obj:`int`, `optional`, defaults to 0):
|
||||||
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`.
|
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`.
|
||||||
logging_dir (:obj:`str`, `optional`):
|
logging_dir (:obj:`str`, `optional`):
|
||||||
@@ -217,6 +220,13 @@ class TrainingArguments:
|
|||||||
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
|
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
|
||||||
training only). This is an experimental feature.
|
training only). This is an experimental feature.
|
||||||
|
label_smoothing_factor (:obj:`float`, `optional`, defaults to 0.0):
|
||||||
|
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
|
||||||
|
labels are changed from 0s and 1s to :obj:`label_smoothing_factor/num_labels` and :obj:`1 -
|
||||||
|
label_smoothing_factor + label_smoothing_factor/num_labels` respectively.
|
||||||
|
adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to use the :class:`~transformers.Adafactor` optimizer instead of
|
||||||
|
:class:`~transformers.AdamW`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
@@ -246,7 +256,7 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
evaluation_strategy: EvaluationStrategy = field(
|
evaluation_strategy: EvaluationStrategy = field(
|
||||||
default="no",
|
default="no",
|
||||||
metadata={"help": "Run evaluation during training at each logging step."},
|
metadata={"help": "The evaluation strategy to use."},
|
||||||
)
|
)
|
||||||
prediction_loss_only: bool = field(
|
prediction_loss_only: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
@@ -296,6 +306,10 @@ class TrainingArguments:
|
|||||||
default=-1,
|
default=-1,
|
||||||
metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
|
metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
|
||||||
)
|
)
|
||||||
|
lr_scheduler_type: SchedulerType = field(
|
||||||
|
default="linear",
|
||||||
|
metadata={"help": "The scheduler type to use."},
|
||||||
|
)
|
||||||
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
|
||||||
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
||||||
@@ -392,11 +406,16 @@ class TrainingArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
|
metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
|
||||||
)
|
)
|
||||||
|
label_smoothing_factor: float = field(
|
||||||
|
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
|
||||||
|
)
|
||||||
|
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."})
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.disable_tqdm is None:
|
if self.disable_tqdm is None:
|
||||||
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
|
||||||
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
|
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
|
||||||
|
self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
|
||||||
if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
|
if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO:
|
||||||
self.do_eval = True
|
self.do_eval = True
|
||||||
if self.eval_steps is None:
|
if self.eval_steps is None:
|
||||||
|
|||||||
42
src/transformers/training_args_seq2seq.py
Normal file
42
src/transformers/training_args_seq2seq.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from .file_utils import add_start_docstrings
|
||||||
|
from .training_args import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@add_start_docstrings(TrainingArguments.__doc__)
|
||||||
|
class Seq2SeqTrainingArguments(TrainingArguments):
|
||||||
|
"""
|
||||||
|
sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether to use a `sortish sampler` or not. Only possible if the underlying datasets are `Seq2SeqDataset` for
|
||||||
|
now but will become generally available in the near future.
|
||||||
|
|
||||||
|
It sorts the inputs according to lengths in order to minimize the padding size, with a bit of randomness for
|
||||||
|
the training set.
|
||||||
|
predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether to use generate to calculate generative metrics (ROUGE, BLEU).
|
||||||
|
"""
|
||||||
|
|
||||||
|
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
|
||||||
|
predict_with_generate: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
||||||
|
)
|
||||||
@@ -2279,6 +2279,10 @@ def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
|
|||||||
requires_pytorch(get_polynomial_decay_schedule_with_warmup)
|
requires_pytorch(get_polynomial_decay_schedule_with_warmup)
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler(*args, **kwargs):
|
||||||
|
requires_pytorch(get_scheduler)
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
@@ -2286,3 +2290,8 @@ class Trainer:
|
|||||||
|
|
||||||
def torch_distributed_zero_first(*args, **kwargs):
|
def torch_distributed_zero_first(*args, **kwargs):
|
||||||
requires_pytorch(torch_distributed_zero_first)
|
requires_pytorch(torch_distributed_zero_first)
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqTrainer:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|||||||
135
tests/test_trainer_seq2seq.py
Normal file
135
tests/test_trainer_seq2seq.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020 the HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
|
||||||
|
from transformers.file_utils import is_datasets_available
|
||||||
|
from transformers.testing_utils import TestCasePlus, require_datasets, slow
|
||||||
|
|
||||||
|
|
||||||
|
if is_datasets_available():
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2seqTrainerTester(TestCasePlus):
|
||||||
|
@slow
|
||||||
|
@require_datasets
|
||||||
|
def test_finetune_bert2bert(self):
|
||||||
|
|
||||||
|
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
|
||||||
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
|
||||||
|
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
|
||||||
|
bert2bert.config.eos_token_id = tokenizer.sep_token_id
|
||||||
|
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
|
||||||
|
bert2bert.config.max_length = 128
|
||||||
|
|
||||||
|
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
|
||||||
|
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
|
||||||
|
|
||||||
|
train_dataset = train_dataset.select(range(32))
|
||||||
|
val_dataset = val_dataset.select(range(16))
|
||||||
|
|
||||||
|
rouge = datasets.load_metric("rouge")
|
||||||
|
|
||||||
|
batch_size = 4
|
||||||
|
|
||||||
|
def _map_to_encoder_decoder_inputs(batch):
|
||||||
|
# Tokenizer will automatically set [BOS] <text> [EOS]
|
||||||
|
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
|
||||||
|
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
|
||||||
|
batch["input_ids"] = inputs.input_ids
|
||||||
|
batch["attention_mask"] = inputs.attention_mask
|
||||||
|
|
||||||
|
batch["decoder_input_ids"] = outputs.input_ids
|
||||||
|
batch["labels"] = outputs.input_ids.copy()
|
||||||
|
batch["labels"] = [
|
||||||
|
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
|
||||||
|
]
|
||||||
|
batch["decoder_attention_mask"] = outputs.attention_mask
|
||||||
|
|
||||||
|
assert all([len(x) == 512 for x in inputs.input_ids])
|
||||||
|
assert all([len(x) == 128 for x in outputs.input_ids])
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def _compute_metrics(pred):
|
||||||
|
labels_ids = pred.label_ids
|
||||||
|
pred_ids = pred.predictions
|
||||||
|
|
||||||
|
# all unnecessary tokens are removed
|
||||||
|
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
||||||
|
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
|
||||||
|
"rouge2"
|
||||||
|
].mid
|
||||||
|
|
||||||
|
return {
|
||||||
|
"rouge2_precision": round(rouge_output.precision, 4),
|
||||||
|
"rouge2_recall": round(rouge_output.recall, 4),
|
||||||
|
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
|
||||||
|
}
|
||||||
|
|
||||||
|
# map train dataset
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
_map_to_encoder_decoder_inputs,
|
||||||
|
batched=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
remove_columns=["article", "highlights"],
|
||||||
|
)
|
||||||
|
train_dataset.set_format(
|
||||||
|
type="torch",
|
||||||
|
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# same for validation dataset
|
||||||
|
val_dataset = val_dataset.map(
|
||||||
|
_map_to_encoder_decoder_inputs,
|
||||||
|
batched=True,
|
||||||
|
batch_size=batch_size,
|
||||||
|
remove_columns=["article", "highlights"],
|
||||||
|
)
|
||||||
|
val_dataset.set_format(
|
||||||
|
type="torch",
|
||||||
|
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
||||||
|
)
|
||||||
|
|
||||||
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
|
||||||
|
training_args = Seq2SeqTrainingArguments(
|
||||||
|
output_dir=output_dir,
|
||||||
|
per_device_train_batch_size=batch_size,
|
||||||
|
per_device_eval_batch_size=batch_size,
|
||||||
|
predict_with_generate=True,
|
||||||
|
evaluation_strategy="steps",
|
||||||
|
do_train=True,
|
||||||
|
do_eval=True,
|
||||||
|
warmup_steps=0,
|
||||||
|
eval_steps=2,
|
||||||
|
logging_steps=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# instantiate trainer
|
||||||
|
trainer = Seq2SeqTrainer(
|
||||||
|
model=bert2bert,
|
||||||
|
args=training_args,
|
||||||
|
compute_metrics=_compute_metrics,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=val_dataset,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# start training
|
||||||
|
trainer.train()
|
||||||
@@ -22,7 +22,10 @@ from transformers.testing_utils import require_torch
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.trainer_pt_utils import DistributedTensorGatherer
|
import torch
|
||||||
|
|
||||||
|
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||||
|
from transformers.trainer_pt_utils import DistributedTensorGatherer, LabelSmoother
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -56,3 +59,31 @@ class TrainerUtilsTest(unittest.TestCase):
|
|||||||
self.assertTrue(np.array_equal(result[0], predictions))
|
self.assertTrue(np.array_equal(result[0], predictions))
|
||||||
self.assertTrue(np.array_equal(result[1][0], predictions))
|
self.assertTrue(np.array_equal(result[1][0], predictions))
|
||||||
self.assertTrue(np.array_equal(result[1][1], predictions))
|
self.assertTrue(np.array_equal(result[1][1], predictions))
|
||||||
|
|
||||||
|
def test_label_smoothing(self):
|
||||||
|
epsilon = 0.1
|
||||||
|
num_labels = 12
|
||||||
|
random_logits = torch.randn(4, 5, num_labels)
|
||||||
|
random_labels = torch.randint(0, num_labels, (4, 5))
|
||||||
|
loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
|
||||||
|
model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
|
||||||
|
label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
|
||||||
|
log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
|
||||||
|
expected_loss = (1 - epsilon) * loss + epsilon * log_probs.mean()
|
||||||
|
self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss))
|
||||||
|
|
||||||
|
# With a few -100 labels
|
||||||
|
random_labels[0, 1] = -100
|
||||||
|
random_labels[2, 1] = -100
|
||||||
|
random_labels[2, 3] = -100
|
||||||
|
|
||||||
|
loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
|
||||||
|
model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
|
||||||
|
label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
|
||||||
|
log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
|
||||||
|
# Mask the log probs with the -100 labels
|
||||||
|
log_probs[0, 1] = 0.0
|
||||||
|
log_probs[2, 1] = 0.0
|
||||||
|
log_probs[2, 3] = 0.0
|
||||||
|
expected_loss = (1 - epsilon) * loss + epsilon * log_probs.sum() / (num_labels * 17)
|
||||||
|
self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss))
|
||||||
|
|||||||
Reference in New Issue
Block a user