[cleanup] T5 test, warnings (#5761)
This commit is contained in:
@@ -2,6 +2,7 @@ import itertools
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List
|
||||
|
||||
@@ -181,11 +182,18 @@ class SortishSampler(Sampler):
|
||||
return iter(sort_idx)
|
||||
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def use_task_specific_params(model, task):
|
||||
# update config with summarization specific params
|
||||
"""Update config with summarization specific params."""
|
||||
task_specific_params = model.config.task_specific_params
|
||||
|
||||
if task_specific_params is not None:
|
||||
model.config.update(task_specific_params.get(task, {}))
|
||||
pars = task_specific_params.get(task, {})
|
||||
logger.info(f"using task specific params for {task}: {pars}")
|
||||
model.config.update(pars)
|
||||
|
||||
|
||||
def pickle_load(path):
|
||||
|
||||
Reference in New Issue
Block a user