Fix doc errors and typos across the board (#8139)
* Fix doc errors and typos across the board
* Fix a typo
* Fix the CI
* Fix more typos
* Fix CI
* More fixes
* Fix CI
* More fixes
* More fixes
This commit is contained in:
@@ -25,14 +25,14 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
class TFGenerationMixin:
|
||||
"""
|
||||
A class contraining all of the functions supporting generation, to be used as a mixin in
|
||||
:class:`~transfomers.TFPreTrainedModel`.
|
||||
A class containing all of the functions supporting generation, to be used as a mixin in
|
||||
:class:`~transformers.TFPreTrainedModel`.
|
||||
"""
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, **kwargs):
|
||||
"""
|
||||
Implement in subclasses of :class:`~transfomers.TFPreTrainedModel` for custom behavior to prepare inputs in the
|
||||
generate method.
|
||||
Implement in subclasses of :class:`~transformers.TFPreTrainedModel` for custom behavior to prepare inputs in
|
||||
the generate method.
|
||||
"""
|
||||
return {"inputs": inputs}
|
||||
|
||||
@@ -216,17 +216,17 @@ class TFGenerationMixin:
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
|
||||
batch_size = shape_list(input_ids)[0] # overridden by the input batch_size
|
||||
else:
|
||||
batch_size = 1
|
||||
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
||||
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
|
||||
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
|
||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
|
||||
assert temperature > 0, "`temperature` should be strictely positive."
|
||||
assert temperature > 0, "`temperature` should be strictly positive."
|
||||
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
||||
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
|
||||
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
||||
@@ -239,10 +239,10 @@ class TFGenerationMixin:
|
||||
assert (eos_token_id is None) or (
|
||||
isinstance(eos_token_id, int) and (eos_token_id >= 0)
|
||||
), "`eos_token_id` should be a positive integer."
|
||||
assert length_penalty > 0, "`length_penalty` should be strictely positive."
|
||||
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
||||
assert (
|
||||
isinstance(num_return_sequences, int) and num_return_sequences > 0
|
||||
), "`num_return_sequences` should be a strictely positive integer."
|
||||
), "`num_return_sequences` should be a strictly positive integer."
|
||||
assert (
|
||||
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
|
||||
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
|
||||
@@ -722,7 +722,7 @@ class TFGenerationMixin:
|
||||
beam_scores[:, None], (batch_size * num_beams, vocab_size)
|
||||
) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
||||
# re-organize to group the beam together (we are keeping top hypothesis across beams)
|
||||
next_scores = tf.reshape(
|
||||
next_scores, (batch_size, num_beams * vocab_size)
|
||||
) # (batch_size, num_beams * vocab_size)
|
||||
@@ -897,7 +897,7 @@ class TFGenerationMixin:
|
||||
|
||||
def adjust_logits_during_generation(self, logits, **kwargs):
|
||||
"""
|
||||
Implement in subclasses of :class:`~transfomers.PreTrainedModel` for custom behavior to adjust the logits in
|
||||
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
|
||||
the generate method.
|
||||
"""
|
||||
return logits
|
||||
@@ -978,7 +978,7 @@ def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
|
||||
|
||||
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
||||
"""
|
||||
Filter a distribution of logits using top-k and/or nucleus (top-p) filterin
|
||||
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
|
||||
Args:
|
||||
logits: logits distribution shape (batch size, vocabulary size)
|
||||
@@ -1047,7 +1047,7 @@ def set_tensor_by_indices_to_value(tensor, indices, value):
|
||||
|
||||
def sample_without_replacement(logits, num_samples):
|
||||
"""
|
||||
categorical sampling witouth replacement is currently not implemented the gumbel-max trick will do for now see
|
||||
categorical sampling without replacement is currently not implemented the gumbel-max trick will do for now see
|
||||
https://github.com/tensorflow/tensorflow/issues/9260 for more info
|
||||
"""
|
||||
z = -tf.math.log(tf.random.uniform(shape_list(logits), 0, 1))
|
||||
|
||||
Reference in New Issue
Block a user