[examples] SummarizationDataset cleanup (#3451)

This commit is contained in:
Sam Shleifer
2020-04-07 19:05:58 -04:00
committed by GitHub
parent b0ad069517
commit e344e3d402
4 changed files with 125 additions and 79 deletions

View File

@@ -47,27 +47,29 @@ def set_seed(args):
class BaseTransformer(pl.LightningModule):
def __init__(self, hparams, num_labels=None, mode="base"):
def __init__(self, hparams, num_labels=None, mode="base", **config_kwargs):
"Initialize a model."
super(BaseTransformer, self).__init__()
self.hparams = hparams
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
self.hparams.model_type = self.hparams.model_type.lower()
config = AutoConfig.from_pretrained(
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
**({"num_labels": num_labels} if num_labels is not None else {}),
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
cache_dir=cache_dir,
**config_kwargs,
)
tokenizer = AutoTokenizer.from_pretrained(
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
do_lower_case=self.hparams.do_lower_case,
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
cache_dir=cache_dir,
)
model = MODEL_MODES[mode].from_pretrained(
self.hparams.model_name_or_path,
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
config=config,
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
cache_dir=cache_dir,
)
self.config, self.tokenizer, self.model = config, tokenizer, model