[GenerationOutputs] Fix GenerationOutputs Tests (#9443)
* fix generation models
* fix led
* fix docs
* add is_decoder
* fix last docstrings
* make style
* fix t5 cross attentions
* correct t5
This commit is contained in:
committed by
GitHub
parent
0c96262f7d
commit
b8462b5b2a
@@ -126,8 +126,7 @@ class SampleDecoderOnlyOutput(ModelOutput):
|
|||||||
sequence_length)`.
|
sequence_length)`.
|
||||||
hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(num_return_sequences * batch_size, generated_length,
|
:obj:`torch.FloatTensor` of shape :obj:`(num_return_sequences*batch_size, generated_length, hidden_size)`.
|
||||||
hidden_size)`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -153,8 +152,8 @@ class SampleEncoderDecoderOutput(ModelOutput):
|
|||||||
at each generation step. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of
|
at each generation step. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of
|
||||||
shape :obj:`(batch_size*num_return_sequences, config.vocab_size)`).
|
shape :obj:`(batch_size*num_return_sequences, config.vocab_size)`).
|
||||||
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for each layer of the decoder) of shape :obj:`(batch_size *
|
Tuple of :obj:`torch.FloatTensor` (one for each layer of the decoder) of shape
|
||||||
num_return_sequences, num_heads, sequence_length, sequence_length)`.
|
:obj:`(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
|
||||||
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||||
of shape :obj:`(batch_size*num_return_sequences, sequence_length, hidden_size)`.
|
of shape :obj:`(batch_size*num_return_sequences, sequence_length, hidden_size)`.
|
||||||
@@ -164,8 +163,7 @@ class SampleEncoderDecoderOutput(ModelOutput):
|
|||||||
sequence_length)`.
|
sequence_length)`.
|
||||||
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_return_sequences, generated_length,
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`.
|
||||||
hidden_size)`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -190,8 +188,8 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
|
|||||||
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
|
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
|
||||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
||||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
||||||
. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape :obj:`(batch_size
|
. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape
|
||||||
* num_beams * num_return_sequences, config.vocab_size)`).
|
:obj:`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
||||||
attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
|
||||||
@@ -225,8 +223,8 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
|||||||
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
|
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
|
||||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
||||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
||||||
. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape :obj:`(batch_size
|
. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape
|
||||||
* num_beams, config.vocab_size)`).
|
:obj:`(batch_size*num_beams, config.vocab_size)`).
|
||||||
attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||||
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for each layer of the decoder) of shape :obj:`(batch_size,
|
Tuple of :obj:`torch.FloatTensor` (one for each layer of the decoder) of shape :obj:`(batch_size,
|
||||||
@@ -267,8 +265,8 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
|
|||||||
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
|
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
|
||||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
||||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
||||||
. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape :obj:`(batch_size
|
. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape
|
||||||
* num_beams * num_return_sequences, config.vocab_size)`).
|
:obj:`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
||||||
attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
|
||||||
@@ -301,8 +299,8 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
|||||||
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
|
scores (:obj:`tuple(torch.FloatTensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
|
||||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
||||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
||||||
. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape :obj:`(batch_size
|
. :obj:`(max_length,)`-shaped tuple of :obj:`torch.FloatTensor` with each tensor of shape
|
||||||
* num_beams, config.vocab_size)`).
|
:obj:`(batch_size*num_beams, config.vocab_size)`).
|
||||||
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`torch.FloatTensor` (one for each layer of the decoder) of shape :obj:`(batch_size,
|
Tuple of :obj:`torch.FloatTensor` (one for each layer of the decoder) of shape :obj:`(batch_size,
|
||||||
num_heads, sequence_length, sequence_length)`.
|
num_heads, sequence_length, sequence_length)`.
|
||||||
|
|||||||
@@ -1227,7 +1227,7 @@ class BertLMHeadModel(BertPreTrainedModel):
|
|||||||
if past is not None:
|
if past is not None:
|
||||||
input_ids = input_ids[:, -1:]
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
||||||
|
|
||||||
def _reorder_cache(self, past, beam_idx):
|
def _reorder_cache(self, past, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -570,7 +570,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
|||||||
if past is not None:
|
if past is not None:
|
||||||
input_ids = input_ids[:, -1:]
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
||||||
|
|
||||||
def _reorder_cache(self, past, beam_idx):
|
def _reorder_cache(self, past, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -455,7 +455,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"decoder_input_ids": decoder_inputs["input_ids"],
|
"decoder_input_ids": decoder_inputs["input_ids"],
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"past_key_values": past,
|
"past_key_values": decoder_inputs["past_key_values"],
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
}
|
}
|
||||||
return input_dict
|
return input_dict
|
||||||
|
|||||||
@@ -962,7 +962,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
|||||||
if past is not None:
|
if past is not None:
|
||||||
input_ids = input_ids[:, -1:]
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
||||||
|
|
||||||
def _reorder_cache(self, past, beam_idx):
|
def _reorder_cache(self, past, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -1148,6 +1148,7 @@ class T5Model(T5PreTrainedModel):
|
|||||||
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
||||||
|
|
||||||
encoder_config = copy.deepcopy(config)
|
encoder_config = copy.deepcopy(config)
|
||||||
|
encoder_config.is_decoder = False
|
||||||
encoder_config.use_cache = False
|
encoder_config.use_cache = False
|
||||||
encoder_config.is_encoder_decoder = False
|
encoder_config.is_encoder_decoder = False
|
||||||
self.encoder = T5Stack(encoder_config, self.shared)
|
self.encoder = T5Stack(encoder_config, self.shared)
|
||||||
@@ -1325,6 +1326,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
||||||
|
|
||||||
encoder_config = copy.deepcopy(config)
|
encoder_config = copy.deepcopy(config)
|
||||||
|
encoder_config.is_decoder = False
|
||||||
encoder_config.use_cache = False
|
encoder_config.use_cache = False
|
||||||
encoder_config.is_encoder_decoder = False
|
encoder_config.is_encoder_decoder = False
|
||||||
self.encoder = T5Stack(encoder_config, self.shared)
|
self.encoder = T5Stack(encoder_config, self.shared)
|
||||||
|
|||||||
@@ -1132,7 +1132,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
|||||||
if past is not None:
|
if past is not None:
|
||||||
input_ids = input_ids[:, -1:]
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
|
||||||
|
|
||||||
def _reorder_cache(self, past, beam_idx):
|
def _reorder_cache(self, past, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
|
|||||||
@@ -522,6 +522,7 @@ class GenerationTesterMixin:
|
|||||||
return
|
return
|
||||||
|
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
|
config.is_decoder = True
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_greedy, output_generate = self._greedy_generate(
|
output_greedy, output_generate = self._greedy_generate(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -730,6 +731,7 @@ class GenerationTesterMixin:
|
|||||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||||
|
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
|
config.is_decoder = True
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_beam, output_generate = self._beam_search_generate(
|
output_beam, output_generate = self._beam_search_generate(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -962,12 +964,7 @@ class GenerationTesterMixin:
|
|||||||
# Attentions
|
# Attentions
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
# encoder
|
# encoder
|
||||||
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
|
self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
|
||||||
self.assertIsInstance(output.encoder_attentions, tuple)
|
|
||||||
self.assertListEqual(
|
|
||||||
[layer_attentions.shape for layer_attentions in output.encoder_attentions],
|
|
||||||
[encoder_expected_shape] * len(output.encoder_attentions),
|
|
||||||
)
|
|
||||||
# decoder
|
# decoder
|
||||||
self._check_attentions_for_generate(
|
self._check_attentions_for_generate(
|
||||||
num_sequences_in_output,
|
num_sequences_in_output,
|
||||||
@@ -993,11 +990,8 @@ class GenerationTesterMixin:
|
|||||||
# Hidden States
|
# Hidden States
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
# encoder
|
# encoder
|
||||||
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
|
self._check_encoder_hidden_states_for_generate(
|
||||||
self.assertIsInstance(output.encoder_hidden_states, tuple)
|
output.encoder_hidden_states, batch_size, config, seq_length
|
||||||
self.assertListEqual(
|
|
||||||
[layer_hidden_states.shape for layer_hidden_states in output.encoder_hidden_states],
|
|
||||||
[encoder_expected_shape] * len(output.encoder_hidden_states),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# decoder
|
# decoder
|
||||||
@@ -1052,6 +1046,14 @@ class GenerationTesterMixin:
|
|||||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
||||||
|
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
|
||||||
|
self.assertIsInstance(attentions, tuple)
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_attentions.shape for layer_attentions in attentions],
|
||||||
|
[encoder_expected_shape] * len(attentions),
|
||||||
|
)
|
||||||
|
|
||||||
def _check_hidden_states_for_generate(
|
def _check_hidden_states_for_generate(
|
||||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||||
):
|
):
|
||||||
@@ -1071,6 +1073,14 @@ class GenerationTesterMixin:
|
|||||||
[expected_shape] * len(iter_hidden_states),
|
[expected_shape] * len(iter_hidden_states),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
|
||||||
|
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
|
||||||
|
self.assertIsInstance(hidden_states, tuple)
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
||||||
|
[encoder_expected_shape] * len(hidden_states),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class UtilsFunctionsTest(unittest.TestCase):
|
class UtilsFunctionsTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -327,6 +327,32 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
# longformer cannot keep gradients in attentions or hidden states
|
# longformer cannot keep gradients in attentions or hidden states
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
||||||
|
# make sure tgt_length is padded
|
||||||
|
tgt_length = (
|
||||||
|
seq_length // config.attention_window[0] + (seq_length % config.attention_window[0] != 0)
|
||||||
|
) * config.attention_window[0]
|
||||||
|
|
||||||
|
encoder_expected_shape = (batch_size, config.num_attention_heads, tgt_length, seq_length)
|
||||||
|
self.assertIsInstance(attentions, tuple)
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_attentions.shape for layer_attentions in attentions],
|
||||||
|
[encoder_expected_shape] * len(attentions),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
|
||||||
|
# make sure seq_length is padded
|
||||||
|
seq_length = (
|
||||||
|
seq_length // config.attention_window[0] + (seq_length % config.attention_window[0] != 0)
|
||||||
|
) * config.attention_window[0]
|
||||||
|
|
||||||
|
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
|
||||||
|
self.assertIsInstance(hidden_states, tuple)
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
||||||
|
[encoder_expected_shape] * len(hidden_states),
|
||||||
|
)
|
||||||
|
|
||||||
def test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config.return_dict = True
|
config.return_dict = True
|
||||||
|
|||||||
Reference in New Issue
Block a user