[EncoderDecoder] Fix initialization and save/load bug (#4680)
* fix bug * add more tests
This commit is contained in:
committed by
GitHub
parent
6f82aea66b
commit
0866669e75
@@ -22,6 +22,7 @@ from transformers import is_torch_available
|
||||
# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
|
||||
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
|
||||
from .test_modeling_bert import BertModelTester
|
||||
from .test_modeling_common import ids_tensor
|
||||
from .utils import require_torch, slow, torch_device
|
||||
|
||||
|
||||
@@ -331,3 +332,33 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
def test_real_bert_model_from_pretrained(self):
|
||||
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_real_bert_model_from_pretrained_has_cross_attention(self):
|
||||
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||
self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention"))
|
||||
|
||||
@slow
|
||||
def test_real_bert_model_save_load_from_pretrained(self):
|
||||
model_2 = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||
model_2.to(torch_device)
|
||||
input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
|
||||
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
|
||||
attention_mask = ids_tensor([13, 5], vocab_size=2)
|
||||
with torch.no_grad():
|
||||
outputs = model_2(input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,)
|
||||
out_2 = outputs[0].cpu().numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
model_2.save_pretrained(tmp_dirname)
|
||||
model_1 = EncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||
model_1.to(torch_device)
|
||||
|
||||
after_outputs = model_1(
|
||||
input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,
|
||||
)
|
||||
out_1 = after_outputs[0].cpu().numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
Reference in New Issue
Block a user