[cleanup] T5 test, warnings (#5761)
This commit is contained in:
@@ -46,9 +46,7 @@ def generate_summaries_or_translations(
|
||||
for batch in tqdm(list(chunks(examples, batch_size))):
|
||||
if "t5" in model_name:
|
||||
batch = [model.config.prefix + text for text in batch]
|
||||
batch = tokenizer(batch, max_length=1024, return_tensors="pt", truncation=True, padding="max_length").to(
|
||||
device
|
||||
)
|
||||
batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
|
||||
input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
|
||||
summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
|
||||
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
|
||||
@@ -2,6 +2,7 @@ import itertools
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List
|
||||
|
||||
@@ -181,11 +182,18 @@ class SortishSampler(Sampler):
|
||||
return iter(sort_idx)
|
||||
|
||||
|
||||
|
||||
# Module-level logger, named after this module.
logger = getLogger(__name__)


def use_task_specific_params(model, task):
    """Update the model config with the parameters registered for ``task``.

    Reads ``model.config.task_specific_params``. When it is ``None`` nothing
    happens. Otherwise the entry stored under ``task`` (an empty dict when the
    task has no entry) is logged and passed to ``model.config.update`` once.

    Args:
        model: object exposing ``config.task_specific_params`` and
            ``config.update``.
        task: key looked up in ``task_specific_params``.
    """
    task_specific_params = model.config.task_specific_params
    if task_specific_params is None:
        return

    pars = task_specific_params.get(task, {})
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("using task specific params for %s: %s", task, pars)
    model.config.update(pars)
|
||||
|
||||
|
||||
def pickle_load(path):
|
||||
|
||||
Reference in New Issue
Block a user