[examples] SummarizationDataset cleanup (#3451)
This commit is contained in:
@@ -47,27 +47,29 @@ def set_seed(args):
|
||||
|
||||
|
||||
class BaseTransformer(pl.LightningModule):
|
||||
def __init__(self, hparams, num_labels=None, mode="base"):
|
||||
def __init__(self, hparams, num_labels=None, mode="base", **config_kwargs):
|
||||
"Initialize a model."
|
||||
|
||||
super(BaseTransformer, self).__init__()
|
||||
self.hparams = hparams
|
||||
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
||||
self.hparams.model_type = self.hparams.model_type.lower()
|
||||
config = AutoConfig.from_pretrained(
|
||||
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
||||
**({"num_labels": num_labels} if num_labels is not None else {}),
|
||||
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
|
||||
cache_dir=cache_dir,
|
||||
**config_kwargs,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||
do_lower_case=self.hparams.do_lower_case,
|
||||
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
model = MODEL_MODES[mode].from_pretrained(
|
||||
self.hparams.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
self.config, self.tokenizer, self.model = config, tokenizer, model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user