[cleanup] T5 test, warnings (#5761)

This commit is contained in:
Sam Shleifer
2020-07-15 08:23:22 -04:00
committed by GitHub
parent ec0a945cf9
commit d0486c8bc2
3 changed files with 55 additions and 97 deletions

View File

@@ -46,9 +46,7 @@ def generate_summaries_or_translations(
for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
batch = tokenizer(batch, max_length=1024, return_tensors="pt", truncation=True, padding="max_length").to(
device
)
batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)