[Beam Search] Correct returned beam scores (#14654)
* better * save intermediate * finish code * up * docs * Apply suggestions from code review * up * add compute transition beam scores function to model and make sure scores are correct with eos * apply nicos comments * Apply suggestions from code review * another fix
This commit is contained in:
committed by
GitHub
parent
e239fc3b0b
commit
8d6acc6c29
@@ -208,10 +208,13 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
|
|||||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Final beam scores of the generated `sequences`.
|
Final beam scores of the generated `sequences`.
|
||||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||||
. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
||||||
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
||||||
|
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
|
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||||
|
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
||||||
@@ -223,6 +226,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
|
|||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
sequences_scores: Optional[torch.FloatTensor] = None
|
sequences_scores: Optional[torch.FloatTensor] = None
|
||||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
|
||||||
@@ -241,10 +245,13 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
|||||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Final beam scores of the generated `sequences`.
|
Final beam scores of the generated `sequences`.
|
||||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||||
. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
||||||
config.vocab_size)`).
|
config.vocab_size)`).
|
||||||
|
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
|
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||||
|
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||||
@@ -267,6 +274,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
|||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
sequences_scores: Optional[torch.FloatTensor] = None
|
sequences_scores: Optional[torch.FloatTensor] = None
|
||||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
@@ -286,10 +294,13 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
|
|||||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Final beam scores of the generated `sequences`.
|
Final beam scores of the generated `sequences`.
|
||||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||||
. `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
||||||
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
||||||
|
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
|
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||||
|
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
||||||
@@ -301,6 +312,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
|
|||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
sequences_scores: Optional[torch.FloatTensor] = None
|
sequences_scores: Optional[torch.FloatTensor] = None
|
||||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
|
||||||
@@ -319,10 +331,13 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
|||||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
sequences_scores (`torch.FloatTensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Final beam scores of the generated `sequences`.
|
Final beam scores of the generated `sequences`.
|
||||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
|
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||||
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
|
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||||
. `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
||||||
config.vocab_size)`).
|
config.vocab_size)`).
|
||||||
|
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
|
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||||
|
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length, sequence_length)`.
|
||||||
@@ -343,6 +358,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
|||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
sequences_scores: Optional[torch.FloatTensor] = None
|
sequences_scores: Optional[torch.FloatTensor] = None
|
||||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
@@ -743,6 +759,45 @@ class GenerationMixin:
|
|||||||
default_list.extend(custom_list)
|
default_list.extend(custom_list)
|
||||||
return default_list
|
return default_list
|
||||||
|
|
||||||
|
def compute_transition_beam_scores(
|
||||||
|
self,
|
||||||
|
sequences: torch.Tensor,
|
||||||
|
scores: Tuple[torch.Tensor],
|
||||||
|
beam_indices: torch.Tensor,
|
||||||
|
eos_token_id: int = None,
|
||||||
|
):
|
||||||
|
"""compute the transition probabilities of sequences given generation
|
||||||
|
scores and beam indices"""
|
||||||
|
|
||||||
|
# reshape scores as [vocab_size * batch_size, # generation steps]
|
||||||
|
# with batch_size being 2 * vocab_size and # generation steps being
|
||||||
|
# seq_len - input_length
|
||||||
|
scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)
|
||||||
|
|
||||||
|
# start of generated tokens
|
||||||
|
cut_idx = sequences.shape[-1] - scores.shape[-1]
|
||||||
|
# adjust for beam indices
|
||||||
|
beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size
|
||||||
|
# compute real indices
|
||||||
|
indices = sequences[:, cut_idx:] + beam_sequence_indices
|
||||||
|
# gather scores and run
|
||||||
|
transition_scores = scores.gather(0, indices)
|
||||||
|
# make sure that if EOS token was used before length of sequence `sequence.shape[-1]`
|
||||||
|
# get first occurence of EOS token
|
||||||
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
|
|
||||||
|
if eos_token_id is not None:
|
||||||
|
is_eos_token_id = sequences[:, cut_idx:] == eos_token_id
|
||||||
|
# make sure first eos token still contributes to transition probs
|
||||||
|
is_eos_token_id[:, -1] = False
|
||||||
|
is_eos_token_id = is_eos_token_id.roll(1, -1)
|
||||||
|
# all indices after eos shoud be masked
|
||||||
|
zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool()
|
||||||
|
# zero out padded probs
|
||||||
|
transition_scores.masked_fill_(zero_transition_prob_mask, 0.0)
|
||||||
|
|
||||||
|
return transition_scores
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
@@ -1871,8 +1926,21 @@ class GenerationMixin:
|
|||||||
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_size = len(beam_scorer._beam_hyps)
|
||||||
|
num_beams = beam_scorer.num_beams
|
||||||
|
|
||||||
|
batch_beam_size, cur_len = input_ids.shape
|
||||||
|
|
||||||
|
if num_beams * batch_size != batch_beam_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||||
|
)
|
||||||
|
|
||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
|
beam_indices = (
|
||||||
|
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
|
||||||
|
)
|
||||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||||
@@ -1884,16 +1952,6 @@ class GenerationMixin:
|
|||||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_size = len(beam_scorer._beam_hyps)
|
|
||||||
num_beams = beam_scorer.num_beams
|
|
||||||
|
|
||||||
batch_beam_size, cur_len = input_ids.shape
|
|
||||||
|
|
||||||
if num_beams * batch_size != batch_beam_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
|
||||||
)
|
|
||||||
|
|
||||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||||
beam_scores[:, 1:] = -1e9
|
beam_scores[:, 1:] = -1e9
|
||||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||||
@@ -1932,13 +1990,13 @@ class GenerationMixin:
|
|||||||
next_token_logits, dim=-1
|
next_token_logits, dim=-1
|
||||||
) # (batch_size * num_beams, vocab_size)
|
) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
next_token_scores = logits_processor(input_ids, next_token_scores)
|
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
||||||
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
|
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
|
||||||
|
|
||||||
# Store scores, attentions and hidden_states when required
|
# Store scores, attentions and hidden_states when required
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if output_scores:
|
if output_scores:
|
||||||
scores += (next_token_scores,)
|
scores += (next_token_scores_processed,)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
decoder_attentions += (
|
decoder_attentions += (
|
||||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||||
@@ -1973,6 +2031,7 @@ class GenerationMixin:
|
|||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
beam_scores = beam_outputs["next_beam_scores"]
|
beam_scores = beam_outputs["next_beam_scores"]
|
||||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||||
beam_idx = beam_outputs["next_beam_indices"]
|
beam_idx = beam_outputs["next_beam_indices"]
|
||||||
@@ -1985,6 +2044,9 @@ class GenerationMixin:
|
|||||||
if model_kwargs["past"] is not None:
|
if model_kwargs["past"] is not None:
|
||||||
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
|
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
|
||||||
|
|
||||||
|
if return_dict_in_generate and output_scores:
|
||||||
|
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
|
||||||
|
|
||||||
# increase cur_len
|
# increase cur_len
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
@@ -2007,11 +2069,20 @@ class GenerationMixin:
|
|||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if not output_scores:
|
if not output_scores:
|
||||||
sequence_outputs["sequence_scores"] = None
|
sequence_outputs["sequence_scores"] = None
|
||||||
|
else:
|
||||||
|
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
||||||
|
# return only as many indices as sequences
|
||||||
|
beam_indices = tuple(
|
||||||
|
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
||||||
|
)
|
||||||
|
beam_indices = sum(beam_indices, ())
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return BeamSearchEncoderDecoderOutput(
|
return BeamSearchEncoderDecoderOutput(
|
||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
|
beam_indices=beam_indices,
|
||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
@@ -2023,6 +2094,7 @@ class GenerationMixin:
|
|||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
|
beam_indices=beam_indices,
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
)
|
)
|
||||||
@@ -2175,8 +2247,16 @@ class GenerationMixin:
|
|||||||
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_size = len(beam_scorer._beam_hyps)
|
||||||
|
num_beams = beam_scorer.num_beams
|
||||||
|
|
||||||
|
batch_beam_size, cur_len = input_ids.shape
|
||||||
|
|
||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
|
beam_indices = (
|
||||||
|
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
|
||||||
|
)
|
||||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||||
@@ -2188,11 +2268,6 @@ class GenerationMixin:
|
|||||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_size = len(beam_scorer._beam_hyps)
|
|
||||||
num_beams = beam_scorer.num_beams
|
|
||||||
|
|
||||||
batch_beam_size, cur_len = input_ids.shape
|
|
||||||
|
|
||||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||||
|
|
||||||
@@ -2231,14 +2306,14 @@ class GenerationMixin:
|
|||||||
next_token_logits, dim=-1
|
next_token_logits, dim=-1
|
||||||
) # (batch_size * num_beams, vocab_size)
|
) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
next_token_scores = logits_processor(input_ids, next_token_scores)
|
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
||||||
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
|
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
|
||||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||||
|
|
||||||
# Store scores, attentions and hidden_states when required
|
# Store scores, attentions and hidden_states when required
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if output_scores:
|
if output_scores:
|
||||||
scores += (next_token_scores,)
|
scores += (logits_warper(input_ids, next_token_scores_processed),)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
decoder_attentions += (
|
decoder_attentions += (
|
||||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||||
@@ -2289,6 +2364,9 @@ class GenerationMixin:
|
|||||||
if model_kwargs["past"] is not None:
|
if model_kwargs["past"] is not None:
|
||||||
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
|
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
|
||||||
|
|
||||||
|
if return_dict_in_generate and output_scores:
|
||||||
|
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
|
||||||
|
|
||||||
# increase cur_len
|
# increase cur_len
|
||||||
cur_len = cur_len + 1
|
cur_len = cur_len + 1
|
||||||
|
|
||||||
@@ -2311,11 +2389,20 @@ class GenerationMixin:
|
|||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if not output_scores:
|
if not output_scores:
|
||||||
sequence_outputs["sequence_scores"] = None
|
sequence_outputs["sequence_scores"] = None
|
||||||
|
else:
|
||||||
|
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
||||||
|
# return only as many indices as sequences
|
||||||
|
beam_indices = tuple(
|
||||||
|
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
||||||
|
)
|
||||||
|
beam_indices = sum(beam_indices, ())
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return BeamSampleEncoderDecoderOutput(
|
return BeamSampleEncoderDecoderOutput(
|
||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
|
beam_indices=beam_indices,
|
||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
@@ -2327,6 +2414,7 @@ class GenerationMixin:
|
|||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
|
beam_indices=beam_indices,
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
)
|
)
|
||||||
@@ -2472,6 +2560,24 @@ class GenerationMixin:
|
|||||||
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_size = len(beam_scorer._beam_hyps)
|
||||||
|
num_beams = beam_scorer.num_beams
|
||||||
|
num_beam_groups = beam_scorer.num_beam_groups
|
||||||
|
num_sub_beams = num_beams // num_beam_groups
|
||||||
|
device = input_ids.device
|
||||||
|
|
||||||
|
batch_beam_size, cur_len = input_ids.shape
|
||||||
|
|
||||||
|
if return_dict_in_generate and output_scores:
|
||||||
|
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
|
||||||
|
else:
|
||||||
|
beam_indices = None
|
||||||
|
|
||||||
|
if num_beams * batch_size != batch_beam_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||||
|
)
|
||||||
|
|
||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
@@ -2485,19 +2591,6 @@ class GenerationMixin:
|
|||||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_size = len(beam_scorer._beam_hyps)
|
|
||||||
num_beams = beam_scorer.num_beams
|
|
||||||
num_beam_groups = beam_scorer.num_beam_groups
|
|
||||||
num_sub_beams = num_beams // num_beam_groups
|
|
||||||
device = input_ids.device
|
|
||||||
|
|
||||||
batch_beam_size, cur_len = input_ids.shape
|
|
||||||
|
|
||||||
if num_beams * batch_size != batch_beam_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
|
||||||
)
|
|
||||||
|
|
||||||
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
||||||
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
||||||
# the same group don't produce same tokens everytime.
|
# the same group don't produce same tokens everytime.
|
||||||
@@ -2564,15 +2657,14 @@ class GenerationMixin:
|
|||||||
) # (batch_size * group_size, vocab_size)
|
) # (batch_size * group_size, vocab_size)
|
||||||
vocab_size = next_token_scores.shape[-1]
|
vocab_size = next_token_scores.shape[-1]
|
||||||
|
|
||||||
next_token_scores = logits_processor(
|
next_token_scores_processed = logits_processor(
|
||||||
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
||||||
)
|
)
|
||||||
next_token_scores = next_token_scores + beam_scores[batch_group_indices].unsqueeze(-1).expand_as(
|
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
||||||
next_token_scores
|
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
||||||
)
|
|
||||||
|
|
||||||
if output_scores:
|
if output_scores:
|
||||||
processed_score[batch_group_indices] = next_token_scores
|
processed_score[batch_group_indices] = next_token_scores_processed
|
||||||
|
|
||||||
# reshape for beam search
|
# reshape for beam search
|
||||||
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
||||||
@@ -2597,6 +2689,11 @@ class GenerationMixin:
|
|||||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||||
beam_idx = beam_outputs["next_beam_indices"]
|
beam_idx = beam_outputs["next_beam_indices"]
|
||||||
|
|
||||||
|
if return_dict_in_generate and output_scores:
|
||||||
|
beam_indices[beam_group_idx] = tuple(
|
||||||
|
beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
|
||||||
|
)
|
||||||
|
|
||||||
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
||||||
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||||
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
||||||
@@ -2655,11 +2752,21 @@ class GenerationMixin:
|
|||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if not output_scores:
|
if not output_scores:
|
||||||
sequence_outputs["sequence_scores"] = None
|
sequence_outputs["sequence_scores"] = None
|
||||||
|
else:
|
||||||
|
beam_indices = sum(beam_indices, ())
|
||||||
|
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
||||||
|
# return only as many indices as sequences
|
||||||
|
beam_indices = tuple(
|
||||||
|
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
||||||
|
)
|
||||||
|
beam_indices = sum(beam_indices, ())
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return BeamSearchEncoderDecoderOutput(
|
return BeamSearchEncoderDecoderOutput(
|
||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
|
beam_indices=beam_indices,
|
||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
|
|||||||
@@ -1903,3 +1903,147 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
output_sequences_with_mask = output_sequences_with_mask.cpu()
|
output_sequences_with_mask = output_sequences_with_mask.cpu()
|
||||||
|
|
||||||
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
|
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
|
||||||
|
|
||||||
|
def test_transition_scores_beam_search_encoder_decoder(self):
|
||||||
|
articles = [
|
||||||
|
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||||
|
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||||
|
]
|
||||||
|
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
model = BartForConditionalGeneration.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bart",
|
||||||
|
max_length=10,
|
||||||
|
num_beams=4,
|
||||||
|
num_return_sequences=2,
|
||||||
|
eos_token_id=None,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=True,
|
||||||
|
length_penalty=0.0,
|
||||||
|
)
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
|
||||||
|
outputs = model.generate(input_ids=input_ids)
|
||||||
|
|
||||||
|
transition_scores = model.compute_transition_beam_scores(
|
||||||
|
outputs.sequences, outputs.scores, outputs.beam_indices
|
||||||
|
)
|
||||||
|
transition_scores_sum = transition_scores.sum(-1)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||||
|
|
||||||
|
def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
|
||||||
|
articles = [
|
||||||
|
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||||
|
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||||
|
]
|
||||||
|
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
model = BartForConditionalGeneration.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bart",
|
||||||
|
max_length=10,
|
||||||
|
num_beams=4,
|
||||||
|
num_return_sequences=2,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=True,
|
||||||
|
length_penalty=0.0,
|
||||||
|
)
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
|
||||||
|
outputs = model.generate(input_ids=input_ids)
|
||||||
|
|
||||||
|
transition_scores = model.compute_transition_beam_scores(
|
||||||
|
outputs.sequences, outputs.scores, outputs.beam_indices
|
||||||
|
)
|
||||||
|
transition_scores_sum = transition_scores.sum(-1)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||||
|
|
||||||
|
def test_transition_scores_beam_search_decoder_only(self):
|
||||||
|
articles = [
|
||||||
|
"Justin Timberlake",
|
||||||
|
"Michael Phelps",
|
||||||
|
]
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
model = GPT2LMHeadModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-gpt2",
|
||||||
|
max_length=10,
|
||||||
|
num_beams=4,
|
||||||
|
num_return_sequences=2,
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
eos_token_id=None,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=True,
|
||||||
|
length_penalty=0.0,
|
||||||
|
)
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
|
||||||
|
outputs = model.generate(input_ids=input_ids)
|
||||||
|
|
||||||
|
transition_scores = model.compute_transition_beam_scores(
|
||||||
|
outputs.sequences, outputs.scores, outputs.beam_indices
|
||||||
|
)
|
||||||
|
transition_scores_sum = transition_scores.sum(-1)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||||
|
|
||||||
|
def test_transition_scores_beam_sample_encoder_decoder(self):
|
||||||
|
articles = [
|
||||||
|
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||||
|
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||||
|
]
|
||||||
|
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
model = BartForConditionalGeneration.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bart",
|
||||||
|
do_sample=True,
|
||||||
|
max_length=10,
|
||||||
|
num_beams=4,
|
||||||
|
num_return_sequences=2,
|
||||||
|
eos_token_id=None,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=True,
|
||||||
|
length_penalty=0.0,
|
||||||
|
)
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
|
||||||
|
outputs = model.generate(input_ids=input_ids)
|
||||||
|
|
||||||
|
transition_scores = model.compute_transition_beam_scores(
|
||||||
|
outputs.sequences, outputs.scores, outputs.beam_indices
|
||||||
|
)
|
||||||
|
transition_scores_sum = transition_scores.sum(-1)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||||
|
|
||||||
|
def test_transition_scores_group_beam_search_encoder_decoder(self):
|
||||||
|
articles = [
|
||||||
|
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||||
|
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||||
|
]
|
||||||
|
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
model = BartForConditionalGeneration.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bart",
|
||||||
|
max_length=10,
|
||||||
|
num_beams=2,
|
||||||
|
num_beam_groups=2,
|
||||||
|
num_return_sequences=2,
|
||||||
|
eos_token_id=None,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=True,
|
||||||
|
length_penalty=0.0,
|
||||||
|
)
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
|
||||||
|
outputs = model.generate(input_ids=input_ids)
|
||||||
|
|
||||||
|
transition_scores = model.compute_transition_beam_scores(
|
||||||
|
outputs.sequences, outputs.scores, outputs.beam_indices
|
||||||
|
)
|
||||||
|
transition_scores_sum = transition_scores.sum(-1)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||||
|
|||||||
Reference in New Issue
Block a user