PegasusForConditionalGeneration (torch version) (#6340)
Co-authored-by: Jingqing Zhang <jingqing.zhang15@imperial.ac.uk>
This commit is contained in:
@@ -23,14 +23,13 @@ RO_CODE = 250020
|
||||
|
||||
|
||||
@require_torch
|
||||
class AbstractMBartIntegrationTest(unittest.TestCase):
|
||||
|
||||
class AbstractSeq2SeqIntegrationTest(unittest.TestCase):
|
||||
maxDiff = 1000 # longer string compare tracebacks
|
||||
checkpoint_name = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
|
||||
cls.pad_token_id = 1
|
||||
return cls
|
||||
|
||||
@cached_property
|
||||
@@ -43,7 +42,7 @@ class AbstractMBartIntegrationTest(unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
|
||||
class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
checkpoint_name = "facebook/mbart-large-en-ro"
|
||||
src_text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
@@ -73,7 +72,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
|
||||
]
|
||||
),
|
||||
}
|
||||
net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
|
||||
net_input["attention_mask"] = net_input["input_ids"].ne(1)
|
||||
with torch.no_grad():
|
||||
logits, *other_stuff = model(**net_input)
|
||||
|
||||
@@ -125,7 +124,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
|
||||
|
||||
|
||||
@require_torch
|
||||
class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
|
||||
class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
checkpoint_name = "facebook/mbart-large-cc25"
|
||||
src_text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
|
||||
Reference in New Issue
Block a user