BART for summarization training with CNN/DM using pytorch-lightning
This commit is contained in:
@@ -53,10 +53,9 @@ class BaseTransformer(pl.LightningModule):
|
||||
super(BaseTransformer, self).__init__()
|
||||
self.hparams = hparams
|
||||
self.hparams.model_type = self.hparams.model_type.lower()
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
**({"num_labels": num_labels} if num_labels is not None else {}),
|
||||
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user