[cleanup] T5 test, warnings (#5761)

This commit is contained in:
Sam Shleifer
2020-07-15 08:23:22 -04:00
committed by GitHub
parent ec0a945cf9
commit d0486c8bc2
3 changed files with 55 additions and 97 deletions

View File

@@ -46,9 +46,7 @@ def generate_summaries_or_translations(
     for batch in tqdm(list(chunks(examples, batch_size))):
         if "t5" in model_name:
             batch = [model.config.prefix + text for text in batch]
-        batch = tokenizer(batch, max_length=1024, return_tensors="pt", truncation=True, padding="max_length").to(
-            device
-        )
+        batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
         input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
         summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
         dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)

View File

@@ -2,6 +2,7 @@ import itertools
 import json
 import os
 import pickle
+from logging import getLogger
 from pathlib import Path
 from typing import Callable, Dict, Iterable, List
@@ -181,11 +182,18 @@ class SortishSampler(Sampler):
         return iter(sort_idx)
+logger = getLogger(__name__)
 def use_task_specific_params(model, task):
-    # update config with summarization specific params
+    """Update config with summarization specific params."""
     task_specific_params = model.config.task_specific_params
     if task_specific_params is not None:
-        model.config.update(task_specific_params.get(task, {}))
+        pars = task_specific_params.get(task, {})
+        logger.info(f"using task specific params for {task}: {pars}")
+        model.config.update(pars)
 def pickle_load(path):

File diff suppressed because one or more lines are too long