BART for summarization training with CNN/DM using pytorch-lightning

This commit is contained in:
Andre Carrera
2020-03-24 19:00:24 -06:00
committed by GitHub
parent eaabaaf750
commit 3d76df3a12
5 changed files with 252 additions and 2 deletions

View File

@@ -53,10 +53,9 @@ class BaseTransformer(pl.LightningModule):
super(BaseTransformer, self).__init__()
self.hparams = hparams
self.hparams.model_type = self.hparams.model_type.lower()
config = AutoConfig.from_pretrained(
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
-num_labels=num_labels,
+**({"num_labels": num_labels} if num_labels is not None else {}),
cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
)
tokenizer = AutoTokenizer.from_pretrained(