Black 20 release

This commit is contained in:
Lysandre
2020-08-26 17:20:22 +02:00
parent e78c110338
commit a75c64d80c
191 changed files with 4807 additions and 3503 deletions

View File

@@ -101,7 +101,13 @@ class T5ModelTester:
)
def check_prepare_lm_labels_via_shift_left(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = T5Model(config=config)
model.to(torch_device)
@@ -134,7 +140,13 @@ class T5ModelTester:
self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
def create_and_check_model(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = T5Model(config=config)
model.to(torch_device)
@@ -160,7 +172,13 @@ class T5ModelTester:
self.parent.assertEqual(len(decoder_past[1][0]), 4)
def create_and_check_with_lm_head(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
outputs = model(
@@ -174,7 +192,13 @@ class T5ModelTester:
self.parent.assertEqual(outputs["loss"].size(), ())
def create_and_check_decoder_model_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = T5Model(config=config).get_decoder().to(torch_device).eval()
# first forward pass
@@ -205,7 +229,13 @@ class T5ModelTester:
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_decoder_model_attention_mask_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = T5Model(config=config).get_decoder()
model.to(torch_device)
@@ -231,7 +261,8 @@ class T5ModelTester:
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1,
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
@@ -249,7 +280,13 @@ class T5ModelTester:
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_generate_with_past_key_value_states(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
torch.manual_seed(0)
@@ -261,14 +298,26 @@ class T5ModelTester:
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
def create_and_check_model_fp16_forward(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = T5Model(config=config).to(torch_device).half().eval()
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_encoder_decoder_shared_weights(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
for model_class in [T5Model, T5ForConditionalGeneration]:
torch.manual_seed(0)
@@ -339,7 +388,14 @@ class T5ModelTester:
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,) = config_and_inputs
(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
@@ -412,7 +468,11 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
model = T5Model(config_and_inputs[0]).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
torch.onnx.export(
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
model,
config_and_inputs[1],
f"{tmpdirname}/t5_test.onnx",
export_params=True,
opset_version=9,
)
@@ -469,7 +529,8 @@ class T5ModelIntegrationTests(unittest.TestCase):
)
decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertListEqual(
expected_summaries, decoded,
expected_summaries,
decoded,
)
@slow