Black 20 release
This commit is contained in:
@@ -101,7 +101,13 @@ class T5ModelTester:
|
||||
)
|
||||
|
||||
def check_prepare_lm_labels_via_shift_left(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
model = T5Model(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -134,7 +140,13 @@ class T5ModelTester:
|
||||
self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
model = T5Model(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -160,7 +172,13 @@ class T5ModelTester:
|
||||
self.parent.assertEqual(len(decoder_past[1][0]), 4)
|
||||
|
||||
def create_and_check_with_lm_head(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
|
||||
outputs = model(
|
||||
@@ -174,7 +192,13 @@ class T5ModelTester:
|
||||
self.parent.assertEqual(outputs["loss"].size(), ())
|
||||
|
||||
def create_and_check_decoder_model_past(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
model = T5Model(config=config).get_decoder().to(torch_device).eval()
|
||||
# first forward pass
|
||||
@@ -205,7 +229,13 @@ class T5ModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_decoder_model_attention_mask_past(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
model = T5Model(config=config).get_decoder()
|
||||
model.to(torch_device)
|
||||
@@ -231,7 +261,8 @@ class T5ModelTester:
|
||||
# append to next input_ids and attn_mask
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
attn_mask = torch.cat(
|
||||
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1,
|
||||
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# get two different outputs
|
||||
@@ -249,7 +280,13 @@ class T5ModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_generate_with_past_key_value_states(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
|
||||
torch.manual_seed(0)
|
||||
@@ -261,14 +298,26 @@ class T5ModelTester:
|
||||
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
|
||||
|
||||
def create_and_check_model_fp16_forward(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
model = T5Model(config=config).to(torch_device).half().eval()
|
||||
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_encoder_decoder_shared_weights(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
for model_class in [T5Model, T5ForConditionalGeneration]:
|
||||
torch.manual_seed(0)
|
||||
@@ -339,7 +388,14 @@ class T5ModelTester:
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,) = config_and_inputs
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
decoder_input_ids,
|
||||
attention_mask,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
) = config_and_inputs
|
||||
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
@@ -412,7 +468,11 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = T5Model(config_and_inputs[0]).to(torch_device)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
|
||||
model,
|
||||
config_and_inputs[1],
|
||||
f"{tmpdirname}/t5_test.onnx",
|
||||
export_params=True,
|
||||
opset_version=9,
|
||||
)
|
||||
|
||||
|
||||
@@ -469,7 +529,8 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
self.assertListEqual(
|
||||
expected_summaries, decoded,
|
||||
expected_summaries,
|
||||
decoded,
|
||||
)
|
||||
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user