fix reformer fp16 (#6237)

This commit is contained in:
Patrick von Platen
2020-08-04 13:02:25 +02:00
committed by GitHub
parent 7ea9b2db37
commit 7f65daa2e1

View File

@@ -389,7 +389,7 @@ class ReformerModelTester:
model.to(torch_device) model.to(torch_device)
model.half() model.half()
model.eval() model.eval()
output = model(input_ids, attention_mask=input_mask)["last_input_state"] output = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item()) self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels):