[XGLM] run sampling test on CPU to be deterministic (#15892)

* run sampling test on CPU to be deterministic

* input_ids on CPU
This commit is contained in:
Suraj Patil
2022-03-02 17:55:49 +01:00
committed by GitHub
parent baab5e7cdf
commit 130b987880

View File

@@ -418,15 +418,14 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase):
def test_xglm_sample(self):
tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")
model.to(torch_device)
torch.manual_seed(0)
tokenized = tokenizer("Today is a nice day and", return_tensors="pt")
input_ids = tokenized.input_ids.to(torch_device)
input_ids = tokenized.input_ids
output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
EXPECTED_OUTPUT_STR = "Today is a nice day and I am happy to show you all about a recent project for my"
EXPECTED_OUTPUT_STR = "Today is a nice day and the sun is shining. A nice day with warm rainy"
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
@slow