[cleanup] T5 test, warnings (#5761)
This commit is contained in:
@@ -46,9 +46,7 @@ def generate_summaries_or_translations(
|
|||||||
for batch in tqdm(list(chunks(examples, batch_size))):
|
for batch in tqdm(list(chunks(examples, batch_size))):
|
||||||
if "t5" in model_name:
|
if "t5" in model_name:
|
||||||
batch = [model.config.prefix + text for text in batch]
|
batch = [model.config.prefix + text for text in batch]
|
||||||
batch = tokenizer(batch, max_length=1024, return_tensors="pt", truncation=True, padding="max_length").to(
|
batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
|
||||||
device
|
|
||||||
)
|
|
||||||
input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
|
input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
|
||||||
summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
|
summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
|
||||||
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import itertools
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Iterable, List
|
from typing import Callable, Dict, Iterable, List
|
||||||
|
|
||||||
@@ -181,11 +182,18 @@ class SortishSampler(Sampler):
|
|||||||
return iter(sort_idx)
|
return iter(sort_idx)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def use_task_specific_params(model, task):
|
def use_task_specific_params(model, task):
|
||||||
# update config with summarization specific params
|
"""Update config with summarization specific params."""
|
||||||
task_specific_params = model.config.task_specific_params
|
task_specific_params = model.config.task_specific_params
|
||||||
|
|
||||||
if task_specific_params is not None:
|
if task_specific_params is not None:
|
||||||
model.config.update(task_specific_params.get(task, {}))
|
pars = task_specific_params.get(task, {})
|
||||||
|
logger.info(f"using task specific params for {task}: {pars}")
|
||||||
|
model.config.update(pars)
|
||||||
|
|
||||||
|
|
||||||
def pickle_load(path):
|
def pickle_load(path):
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user