PegasusForConditionalGeneration (torch version) (#6340)

Co-authored-by: Jingqing  Zhang <jingqing.zhang15@imperial.ac.uk>
This commit is contained in:
Sam Shleifer
2020-08-11 14:31:23 -04:00
committed by GitHub
parent f6cb0f806e
commit 66fa8ceaea
20 changed files with 860 additions and 20 deletions

View File

@@ -23,14 +23,13 @@ RO_CODE = 250020
@require_torch
class AbstractMBartIntegrationTest(unittest.TestCase):
class AbstractSeq2SeqIntegrationTest(unittest.TestCase):
maxDiff = 1000 # longer string compare tracebacks
checkpoint_name = None
@classmethod
def setUpClass(cls):
cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
cls.pad_token_id = 1
return cls
@cached_property
@@ -43,7 +42,7 @@ class AbstractMBartIntegrationTest(unittest.TestCase):
@require_torch
class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
checkpoint_name = "facebook/mbart-large-en-ro"
src_text = [
" UN Chief Says There Is No Military Solution in Syria",
@@ -73,7 +72,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
]
),
}
net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
net_input["attention_mask"] = net_input["input_ids"].ne(1)
with torch.no_grad():
logits, *other_stuff = model(**net_input)
@@ -125,7 +124,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
@require_torch
class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
checkpoint_name = "facebook/mbart-large-cc25"
src_text = [
" UN Chief Says There Is No Military Solution in Syria",