Black 20 release
This commit is contained in:
@@ -330,7 +330,9 @@ class ModelTesterMixin:
|
||||
# Prepare head_mask
|
||||
# Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
|
||||
head_mask = torch.ones(
|
||||
self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device,
|
||||
self.model_tester.num_hidden_layers,
|
||||
self.model_tester.num_attention_heads,
|
||||
device=torch_device,
|
||||
)
|
||||
head_mask[0, 0] = 0
|
||||
head_mask[-1, :-1] = 0
|
||||
@@ -370,7 +372,10 @@ class ModelTesterMixin:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
(
|
||||
config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
@@ -399,7 +404,10 @@ class ModelTesterMixin:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
(
|
||||
config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
@@ -432,7 +440,10 @@ class ModelTesterMixin:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
(
|
||||
config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
@@ -463,7 +474,10 @@ class ModelTesterMixin:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
(
|
||||
config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
@@ -534,7 +548,8 @@ class ModelTesterMixin:
|
||||
seq_length = self.model_tester.seq_length
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size],
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -550,7 +565,10 @@ class ModelTesterMixin:
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
torch.manual_seed(0)
|
||||
config = copy.deepcopy(original_config)
|
||||
@@ -570,7 +588,10 @@ class ModelTesterMixin:
|
||||
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
return
|
||||
|
||||
@@ -844,7 +865,14 @@ class ModelTesterMixin:
|
||||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2, num_return_sequences=2,))
|
||||
self._check_generated_ids(
|
||||
model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
)
|
||||
# num_return_sequences > 1, greedy
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user