Have seq2seq just use gather (#27025)

* Have seq2seq just use gather

* Change

* Reset after

* Make slow

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Clean

* Simplify and just use gather

* Update tests/trainer/test_trainer_seq2seq.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* gather always for seq2seq

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Zach Mueller
2023-11-14 14:54:44 -05:00
committed by GitHub
parent 250032e974
commit 067c4a310d
3 changed files with 70 additions and 7 deletions

View File

@@ -12,8 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import (
AutoModelForSeq2SeqLM,
BertTokenizer,
DataCollatorForSeq2Seq,
EncoderDecoderModel,
GenerationConfig,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
T5Tokenizer,
)
from transformers.testing_utils import TestCasePlus, require_torch, slow
from transformers.utils import is_datasets_available
@@ -124,3 +132,52 @@ class Seq2seqTrainerTester(TestCasePlus):
# start training
trainer.train()
@slow
@require_torch
def test_return_sequences(self):
# Tests that the number of generated sequences is correct when num_return_sequences > 1
# and essentially ensuring that `accelerator.gather()` is used instead of `gather_for_metrics`
INPUT_COLUMN = "question"
TARGET_COLUMN = "answer"
MAX_INPUT_LENGTH = 256
MAX_TARGET_LENGTH = 256
dataset = datasets.load_dataset("gsm8k", "main", split="train[:38]")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest")
gen_config = GenerationConfig.from_pretrained(
"t5-small", max_length=None, min_length=None, max_new_tokens=256, min_new_tokens=1, num_beams=5
)
training_args = Seq2SeqTrainingArguments(".", predict_with_generate=True)
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=lambda x: {"samples": x[0].shape[0]},
)
def prepare_data(examples):
# Remove pairs where at least one record is none
inputs = examples[INPUT_COLUMN]
targets = examples[TARGET_COLUMN]
model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, truncation=True)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
prepared_dataset = dataset.map(prepare_data, batched=True, remove_columns=[INPUT_COLUMN, TARGET_COLUMN])
dataset_len = len(prepared_dataset) # 38
for num_return_sequences in range(3, 0, -1):
gen_config.num_return_sequences = num_return_sequences
metrics = trainer.evaluate(eval_dataset=prepared_dataset, generation_config=gen_config)
assert (
metrics["eval_samples"] == dataset_len * num_return_sequences
), f"Got {metrics['eval_samples']}, expected: {dataset_len * num_return_sequences}"