Have seq2seq just use gather (#27025)
* Have seq2seq just use gather
* Change
* Reset after
* Make slow
* Apply suggestions from code review
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Clean
* Simplify and just use gather
* Update tests/trainer/test_trainer_seq2seq.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* gather always for seq2seq
---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -12,8 +12,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
BertTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
EncoderDecoderModel,
|
||||
GenerationConfig,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
T5Tokenizer,
|
||||
)
|
||||
from transformers.testing_utils import TestCasePlus, require_torch, slow
|
||||
from transformers.utils import is_datasets_available
|
||||
|
||||
@@ -124,3 +132,52 @@ class Seq2seqTrainerTester(TestCasePlus):
|
||||
|
||||
# start training
|
||||
trainer.train()
|
||||
|
||||
@slow
@require_torch
def test_return_sequences(self):
    """Check that evaluation reports one prediction row per returned sequence.

    Evaluates a `t5-small` model on a 38-example slice of GSM8K with
    `num_return_sequences` set to 3, 2 and 1 in turn, and asserts that the
    number of rows seen by `compute_metrics` equals
    `len(dataset) * num_return_sequences`.

    This guards the behaviour that the seq2seq evaluation path uses
    `accelerator.gather()` rather than `gather_for_metrics`, so that the extra
    generated sequences are all counted.
    """
    INPUT_COLUMN = "question"
    TARGET_COLUMN = "answer"
    MAX_INPUT_LENGTH = 256
    MAX_TARGET_LENGTH = 256

    dataset = datasets.load_dataset("gsm8k", "main", split="train[:38]")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest")
    gen_config = GenerationConfig.from_pretrained(
        "t5-small", max_length=None, min_length=None, max_new_tokens=256, min_new_tokens=1, num_beams=5
    )

    # Use a self-cleaning temporary directory instead of "." so that running
    # the test never leaves trainer artifacts in the current working directory.
    output_dir = self.get_auto_remove_tmp_dir()
    training_args = Seq2SeqTrainingArguments(output_dir, predict_with_generate=True)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        # Report only the size of the first axis of x[0]; the assertion below
        # compares it with the expected number of generated sequences.
        compute_metrics=lambda x: {"samples": x[0].shape[0]},
    )

    def prepare_data(examples):
        """Tokenize a batch: questions become model inputs, answers become labels."""
        inputs = examples[INPUT_COLUMN]
        targets = examples[TARGET_COLUMN]

        model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
        labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, truncation=True)
        model_inputs["labels"] = labels["input_ids"]

        return model_inputs

    prepared_dataset = dataset.map(prepare_data, batched=True, remove_columns=[INPUT_COLUMN, TARGET_COLUMN])
    dataset_len = len(prepared_dataset)  # 38, from the "train[:38]" split above

    # Count down 3 -> 2 -> 1 so the multi-sequence cases are followed by the
    # single-sequence case, all on the same trainer instance.
    for num_return_sequences in range(3, 0, -1):
        gen_config.num_return_sequences = num_return_sequences
        metrics = trainer.evaluate(eval_dataset=prepared_dataset, generation_config=gen_config)
        assert (
            metrics["eval_samples"] == dataset_len * num_return_sequences
        ), f"Got {metrics['eval_samples']}, expected: {dataset_len * num_return_sequences}"
|
||||
|
||||
Reference in New Issue
Block a user