Have seq2seq just use gather (#27025)
* Have seq2seq just use gather
* Change
* Reset after
* Make slow
* Apply suggestions from code review

  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Clean
* Simplify and just use gather
* Update tests/trainer/test_trainer_seq2seq.py

  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* gather always for seq2seq

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -3208,13 +3208,13 @@ class Trainer:
|
|||||||
|
|
||||||
# Update containers on host
|
# Update containers on host
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))
|
losses = self.gather_function((loss.repeat(batch_size)))
|
||||||
losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
|
losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
|
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
|
||||||
if inputs_decode is not None:
|
if inputs_decode is not None:
|
||||||
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
|
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
|
||||||
inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))
|
inputs_decode = self.gather_function((inputs_decode))
|
||||||
inputs_host = (
|
inputs_host = (
|
||||||
inputs_decode
|
inputs_decode
|
||||||
if inputs_host is None
|
if inputs_host is None
|
||||||
@@ -3224,11 +3224,11 @@ class Trainer:
|
|||||||
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
|
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
|
||||||
if self.preprocess_logits_for_metrics is not None:
|
if self.preprocess_logits_for_metrics is not None:
|
||||||
logits = self.preprocess_logits_for_metrics(logits, labels)
|
logits = self.preprocess_logits_for_metrics(logits, labels)
|
||||||
logits = self.accelerator.gather_for_metrics((logits))
|
logits = self.gather_function((logits))
|
||||||
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
|
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = self.accelerator.gather_for_metrics((labels))
|
labels = self.gather_function((labels))
|
||||||
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
|
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
|
||||||
|
|
||||||
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
||||||
@@ -3261,6 +3261,8 @@ class Trainer:
|
|||||||
# Set back to None to begin a new accumulation
|
# Set back to None to begin a new accumulation
|
||||||
losses_host, preds_host, inputs_host, labels_host = None, None, None, None
|
losses_host, preds_host, inputs_host, labels_host = None, None, None, None
|
||||||
|
|
||||||
|
# After all calls to `.gather_function`, reset to `gather_for_metrics`:
|
||||||
|
self.gather_function = self.accelerator.gather_for_metrics
|
||||||
if args.past_index and hasattr(self, "_past"):
|
if args.past_index and hasattr(self, "_past"):
|
||||||
# Clean the state at the end of the evaluation loop
|
# Clean the state at the end of the evaluation loop
|
||||||
delattr(self, "_past")
|
delattr(self, "_past")
|
||||||
@@ -3930,6 +3932,8 @@ class Trainer:
|
|||||||
deepspeed_plugin=self.args.deepspeed_plugin,
|
deepspeed_plugin=self.args.deepspeed_plugin,
|
||||||
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
||||||
)
|
)
|
||||||
|
# Some Trainer classes need to use `gather` instead of `gather_for_metrics`, so we store the gather function to call
|
||||||
|
self.gather_function = self.accelerator.gather_for_metrics
|
||||||
|
|
||||||
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
|
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
|
||||||
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||||
|
|||||||
@@ -160,8 +160,9 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
gen_kwargs["max_length"] = self.args.generation_max_length
|
gen_kwargs["max_length"] = self.args.generation_max_length
|
||||||
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
|
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
|
||||||
gen_kwargs["num_beams"] = self.args.generation_num_beams
|
gen_kwargs["num_beams"] = self.args.generation_num_beams
|
||||||
|
# We don't want to drop samples in general
|
||||||
|
self.gather_function = self.accelerator.gather
|
||||||
self._gen_kwargs = gen_kwargs
|
self._gen_kwargs = gen_kwargs
|
||||||
|
|
||||||
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
@@ -223,6 +224,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
gen_kwargs["max_length"] = self.args.generation_max_length
|
gen_kwargs["max_length"] = self.args.generation_max_length
|
||||||
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
|
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
|
||||||
gen_kwargs["num_beams"] = self.args.generation_num_beams
|
gen_kwargs["num_beams"] = self.args.generation_num_beams
|
||||||
|
self.gather_function = self.accelerator.gather
|
||||||
self._gen_kwargs = gen_kwargs
|
self._gen_kwargs = gen_kwargs
|
||||||
|
|
||||||
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||||
|
|||||||
@@ -12,8 +12,16 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from transformers import (
|
||||||
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
|
AutoModelForSeq2SeqLM,
|
||||||
|
BertTokenizer,
|
||||||
|
DataCollatorForSeq2Seq,
|
||||||
|
EncoderDecoderModel,
|
||||||
|
GenerationConfig,
|
||||||
|
Seq2SeqTrainer,
|
||||||
|
Seq2SeqTrainingArguments,
|
||||||
|
T5Tokenizer,
|
||||||
|
)
|
||||||
from transformers.testing_utils import TestCasePlus, require_torch, slow
|
from transformers.testing_utils import TestCasePlus, require_torch, slow
|
||||||
from transformers.utils import is_datasets_available
|
from transformers.utils import is_datasets_available
|
||||||
|
|
||||||
@@ -124,3 +132,52 @@ class Seq2seqTrainerTester(TestCasePlus):
|
|||||||
|
|
||||||
# start training
|
# start training
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_return_sequences(self):
|
||||||
|
# Tests that the number of generated sequences is correct when num_return_sequences > 1
|
||||||
|
# and, in effect, ensures that `accelerator.gather()` is used instead of `gather_for_metrics`
|
||||||
|
INPUT_COLUMN = "question"
|
||||||
|
TARGET_COLUMN = "answer"
|
||||||
|
MAX_INPUT_LENGTH = 256
|
||||||
|
MAX_TARGET_LENGTH = 256
|
||||||
|
|
||||||
|
dataset = datasets.load_dataset("gsm8k", "main", split="train[:38]")
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
|
||||||
|
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||||
|
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest")
|
||||||
|
gen_config = GenerationConfig.from_pretrained(
|
||||||
|
"t5-small", max_length=None, min_length=None, max_new_tokens=256, min_new_tokens=1, num_beams=5
|
||||||
|
)
|
||||||
|
|
||||||
|
training_args = Seq2SeqTrainingArguments(".", predict_with_generate=True)
|
||||||
|
|
||||||
|
trainer = Seq2SeqTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
data_collator=data_collator,
|
||||||
|
compute_metrics=lambda x: {"samples": x[0].shape[0]},
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_data(examples):
|
||||||
|
# Tokenize the inputs and targets; the target token ids become the labels
|
||||||
|
inputs = examples[INPUT_COLUMN]
|
||||||
|
targets = examples[TARGET_COLUMN]
|
||||||
|
|
||||||
|
model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
|
||||||
|
labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, truncation=True)
|
||||||
|
model_inputs["labels"] = labels["input_ids"]
|
||||||
|
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
prepared_dataset = dataset.map(prepare_data, batched=True, remove_columns=[INPUT_COLUMN, TARGET_COLUMN])
|
||||||
|
dataset_len = len(prepared_dataset) # 38
|
||||||
|
|
||||||
|
for num_return_sequences in range(3, 0, -1):
|
||||||
|
gen_config.num_return_sequences = num_return_sequences
|
||||||
|
metrics = trainer.evaluate(eval_dataset=prepared_dataset, generation_config=gen_config)
|
||||||
|
assert (
|
||||||
|
metrics["eval_samples"] == dataset_len * num_return_sequences
|
||||||
|
), f"Got {metrics['eval_samples']}, expected: {dataset_len * num_return_sequences}"
|
||||||
|
|||||||
Reference in New Issue
Block a user