Fix doc errors and typos across the board (#8139)
* Fix doc errors and typos across the board
* Fix a typo
* Fix the CI
* Fix more typos
* Fix CI
* More fixes
* Fix CI
* More fixes
* More fixes
This commit is contained in:
@@ -29,20 +29,20 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
class GenerationMixin:
|
||||
"""
|
||||
A class contraining all of the functions supporting generation, to be used as a mixin in
|
||||
:class:`~transfomers.PreTrainedModel`.
|
||||
A class containing all of the functions supporting generation, to be used as a mixin in
|
||||
:class:`~transformers.PreTrainedModel`.
|
||||
"""
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||
"""
|
||||
Implement in subclasses of :class:`~transfomers.PreTrainedModel` for custom behavior to prepare inputs in the
|
||||
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the
|
||||
generate method.
|
||||
"""
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
def adjust_logits_during_generation(self, logits, **kwargs):
|
||||
"""
|
||||
Implement in subclasses of :class:`~transfomers.PreTrainedModel` for custom behavior to adjust the logits in
|
||||
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
|
||||
the generate method.
|
||||
"""
|
||||
return logits
|
||||
@@ -285,7 +285,7 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0] # overriden by the input batch_size
|
||||
batch_size = input_ids.shape[0] # overridden by the input batch_size
|
||||
else:
|
||||
batch_size = 1
|
||||
|
||||
@@ -533,7 +533,7 @@ class GenerationMixin:
|
||||
):
|
||||
"""
|
||||
Generate sequences for each example without beam search (num_beams == 1). All returned sequences are generated
|
||||
independantly.
|
||||
independently.
|
||||
"""
|
||||
# length of generated sentences / unfinished sentences
|
||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||
@@ -600,7 +600,7 @@ class GenerationMixin:
|
||||
# unfinished_sents is set to zero if eos in sentence
|
||||
unfinished_sents.mul_((~eos_in_sents).long())
|
||||
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximul length
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximum length
|
||||
if unfinished_sents.max() == 0:
|
||||
break
|
||||
|
||||
@@ -724,7 +724,7 @@ class GenerationMixin:
|
||||
else:
|
||||
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
||||
# re-organize to group the beam together (we are keeping top hypothesis across beams)
|
||||
next_scores = next_scores.view(
|
||||
batch_size, num_beams * vocab_size
|
||||
) # (batch_size, num_beams * vocab_size)
|
||||
@@ -969,7 +969,7 @@ def top_k_top_p_filtering(
|
||||
min_tokens_to_keep: int = 1,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Filter a distribution of logits using top-k and/or nucleus (top-p) filterin
|
||||
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
|
||||
Args:
|
||||
logits: logits distribution shape (batch size, vocabulary size)
|
||||
|
||||
Reference in New Issue
Block a user