Fix 3 failing slow bart/blender tests (#7652)

This commit is contained in:
Sam Shleifer
2020-10-07 22:05:03 -04:00
committed by GitHub
parent 960faaaf28
commit e3e6517355
2 changed files with 8 additions and 18 deletions

View File

@@ -134,38 +134,31 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
class Blenderbot3BIntegrationTests(unittest.TestCase):
ckpt = "facebook/blenderbot-3B"
@cached_property
def model(self):
model = BlenderbotForConditionalGeneration.from_pretrained(self.ckpt).to(torch_device)
if torch_device == "cuda":
model = model.half()
return model
@cached_property
def tokenizer(self):
return BlenderbotTokenizer.from_pretrained(self.ckpt)
@slow
def test_generation_from_short_input_same_as_parlai_3B(self):
torch.cuda.empty_cache()
model = BlenderbotForConditionalGeneration.from_pretrained(self.ckpt).half().to(torch_device)
src_text = ["Sam"]
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
generated_utterances = self.model.generate(**model_inputs, **FASTER_GEN_KWARGS)
generated_utterances = model.generate(**model_inputs, **FASTER_GEN_KWARGS)
tgt_text = 'Sam is a great name. It means "sun" in Gaelic.'
generated_txt = self.tokenizer.batch_decode(generated_utterances, **TOK_DECODE_KW)
assert generated_txt[0].strip() == tgt_text
@slow
def test_generation_from_long_input_same_as_parlai_3B(self):
src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device)
generated_ids = self.model.generate(**model_inputs, **FASTER_GEN_KWARGS)[0]
generated_ids = model.generate(**model_inputs, **FASTER_GEN_KWARGS)[0]
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
assert "I think it's because we are so worried about what people think of us." == reply.strip()
del model
@require_torch
@@ -193,7 +186,6 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
assert self.model.config.do
generated_ids = self.model.generate(**model_inputs)[0]
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)