TF: GPT-2 generation supports left-padding (#17426)
* TF GPT-2 now properly works with left padding
* throw a warning when eos token == pad token and there is no attention mask
This commit is contained in:
@@ -1498,8 +1498,14 @@ class TFGenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if pad_token_id is None and eos_token_id is not None:
|
if pad_token_id is None and eos_token_id is not None:
|
||||||
|
if attention_mask is None:
|
||||||
|
logger.warning(
|
||||||
|
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||||
|
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||||
|
)
|
||||||
logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
|
logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
|
||||||
pad_token_id = eos_token_id
|
pad_token_id = eos_token_id
|
||||||
|
|
||||||
if min_length is not None and min_length > max_length:
|
if min_length is not None and min_length > max_length:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
|
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
|
||||||
@@ -1525,7 +1531,9 @@ class TFGenerationMixin:
|
|||||||
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||||
|
|
||||||
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
|
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
|
||||||
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id)
|
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||||
|
input_ids, pad_token_id, eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
# 4. Prepare model inputs which will be used for auto-regressive generation
|
# 4. Prepare model inputs which will be used for auto-regressive generation
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
@@ -1653,12 +1661,17 @@ class TFGenerationMixin:
|
|||||||
def _prepare_attention_mask_for_generation(
|
def _prepare_attention_mask_for_generation(
|
||||||
self,
|
self,
|
||||||
inputs: tf.Tensor,
|
inputs: tf.Tensor,
|
||||||
pad_token_id: int,
|
pad_token_id: Optional[int],
|
||||||
|
eos_token_id: Optional[int],
|
||||||
) -> tf.Tensor:
|
) -> tf.Tensor:
|
||||||
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
|
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
|
||||||
is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
|
is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
|
||||||
|
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
|
||||||
|
(eos_token_id is not None) and (pad_token_id != eos_token_id)
|
||||||
|
)
|
||||||
|
|
||||||
# Check if input is input_ids and padded -> only then is attention_mask defined
|
# Check if input is input_ids and padded -> only then is attention_mask defined
|
||||||
if is_input_ids and is_pad_token_in_inputs:
|
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
|
||||||
return tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32)
|
return tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32)
|
||||||
else:
|
else:
|
||||||
return tf.ones(inputs.shape[:2], dtype=tf.int32)
|
return tf.ones(inputs.shape[:2], dtype=tf.int32)
|
||||||
@@ -1954,6 +1967,7 @@ class TFGenerationMixin:
|
|||||||
# 1. init greedy_search values
|
# 1. init greedy_search values
|
||||||
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
|
||||||
|
|
||||||
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||||
@@ -1973,10 +1987,9 @@ class TFGenerationMixin:
|
|||||||
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
# 3. init tensors to use for "xla-compileable" generate function
|
# 3. init tensors to use for "xla-compileable" generate function
|
||||||
# define bsz, seq_length
|
batch_size, cur_len = input_ids.shape
|
||||||
batch_size, seq_length = input_ids.shape
|
|
||||||
|
|
||||||
# initialize `generated`, `finished_sequences`, and `current_pos`
|
# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
|
||||||
generated = tf.TensorArray(
|
generated = tf.TensorArray(
|
||||||
element_shape=(batch_size,),
|
element_shape=(batch_size,),
|
||||||
dtype=tf.int32,
|
dtype=tf.int32,
|
||||||
@@ -1984,25 +1997,26 @@ class TFGenerationMixin:
|
|||||||
size=max_length,
|
size=max_length,
|
||||||
clear_after_read=False,
|
clear_after_read=False,
|
||||||
)
|
)
|
||||||
|
if pad_token_id: # ignores the cases when it is 0 or None
|
||||||
|
for i in range(max_length):
|
||||||
|
generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
|
||||||
|
|
||||||
# write prompt to generated
|
# write prompt to generated
|
||||||
for i in range(seq_length):
|
for i in range(cur_len):
|
||||||
generated = generated.write(i, input_ids[:, i])
|
generated = generated.write(i, input_ids[:, i])
|
||||||
|
|
||||||
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
|
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
|
||||||
current_pos = tf.ones(shape=(1,), dtype=tf.int32) * seq_length
|
|
||||||
|
|
||||||
# 4. define "xla-compile-able" stop-condition and auto-regressive function
|
# 4. define "xla-compile-able" stop-condition and auto-regressive function
|
||||||
# define condition fn
|
# define condition fn
|
||||||
def greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
|
def greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||||
"""state termination condition fn."""
|
"""state termination condition fn."""
|
||||||
return ~tf.reduce_all(finished_sequences)
|
return ~tf.reduce_all(finished_sequences)
|
||||||
|
|
||||||
# define state update (body) fn
|
# define state update (body) fn
|
||||||
def greedy_search_body_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
|
def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||||
"""state update fn."""
|
"""state update fn."""
|
||||||
# TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
|
model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
|
||||||
model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
|
|
||||||
# forward pass to get next token logits
|
# forward pass to get next token logits
|
||||||
outputs = self(
|
outputs = self(
|
||||||
**model_inputs,
|
**model_inputs,
|
||||||
@@ -2029,13 +2043,8 @@ class TFGenerationMixin:
|
|||||||
decoder_hidden_states.append(outputs.hidden_states)
|
decoder_hidden_states.append(outputs.hidden_states)
|
||||||
|
|
||||||
# pre-process distribution
|
# pre-process distribution
|
||||||
# TODO(pvp, joao, matt) - all the logits processors need to be adapted
|
input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
|
||||||
# to be XLA compatible
|
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
|
||||||
input_ids = None
|
|
||||||
if not use_xla:
|
|
||||||
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
|
|
||||||
input_ids = tf.transpose(input_ids[: current_pos[0]])
|
|
||||||
next_tokens_scores = logits_processor(input_ids, next_token_logits, current_pos[0])
|
|
||||||
|
|
||||||
# argmax
|
# argmax
|
||||||
next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
|
next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
|
||||||
@@ -2047,16 +2056,14 @@ class TFGenerationMixin:
|
|||||||
next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
|
next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
|
||||||
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
|
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
|
||||||
|
|
||||||
# update `generated` and `current_pos`
|
# update `generated` and `cur_len`
|
||||||
generated = generated.write(current_pos[0], next_tokens)
|
generated = generated.write(cur_len, next_tokens)
|
||||||
next_tokens = next_tokens[:, None]
|
next_tokens = next_tokens[:, None]
|
||||||
current_pos += 1
|
cur_len += 1
|
||||||
|
|
||||||
# update model_kwargs
|
# update model_kwargs
|
||||||
if use_xla:
|
if use_xla:
|
||||||
model_kwargs = self._update_model_kwargs_for_xla_generation(
|
model_kwargs = self._update_model_kwargs_for_xla_generation(outputs, model_kwargs, cur_len, max_length)
|
||||||
outputs, model_kwargs, current_pos, max_length
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
model_kwargs = self._update_model_kwargs_for_generation(
|
model_kwargs = self._update_model_kwargs_for_generation(
|
||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
@@ -2067,24 +2074,24 @@ class TFGenerationMixin:
|
|||||||
model_kwargs.pop("past", None)
|
model_kwargs.pop("past", None)
|
||||||
|
|
||||||
next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
|
next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
|
||||||
next_tokens = tf.transpose(next_tokens[: current_pos[0]])
|
next_tokens = tf.transpose(next_tokens[:cur_len])
|
||||||
|
|
||||||
return generated, finished_sequences, next_tokens, current_pos, model_kwargs
|
return generated, finished_sequences, next_tokens, cur_len, model_kwargs
|
||||||
|
|
||||||
# 5. run generation
|
# 5. run generation
|
||||||
# 1st generation step has to be run before to initialize `past`
|
# 1st generation step has to be run before to initialize `past`
|
||||||
generated, finished_sequences, next_tokens, current_pos, model_kwargs = greedy_search_body_fn(
|
generated, finished_sequences, next_tokens, cur_len, model_kwargs = greedy_search_body_fn(
|
||||||
generated, finished_sequences, input_ids, current_pos, model_kwargs
|
generated, finished_sequences, input_ids, cur_len, model_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2-to-n generation steps can then be run in autoregressive fashion
|
# 2-to-n generation steps can then be run in autoregressive fashion
|
||||||
# only in case 1st generation step does NOT yield EOS token though
|
# only in case 1st generation step does NOT yield EOS token though
|
||||||
if greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
|
if greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||||
maximum_iterations = max_length - seq_length - 1
|
maximum_iterations = max_length - cur_len
|
||||||
generated, _, _, current_pos, _ = tf.while_loop(
|
generated, _, _, cur_len, _ = tf.while_loop(
|
||||||
greedy_search_cond_fn,
|
greedy_search_cond_fn,
|
||||||
greedy_search_body_fn,
|
greedy_search_body_fn,
|
||||||
(generated, finished_sequences, next_tokens, current_pos, model_kwargs),
|
(generated, finished_sequences, next_tokens, cur_len, model_kwargs),
|
||||||
maximum_iterations=maximum_iterations,
|
maximum_iterations=maximum_iterations,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2093,7 +2100,7 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
if not use_xla:
|
if not use_xla:
|
||||||
# cut for backward compatibility
|
# cut for backward compatibility
|
||||||
output_ids = output_ids[:, : current_pos[0]]
|
output_ids = output_ids[:, :cur_len]
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
@@ -2231,6 +2238,7 @@ class TFGenerationMixin:
|
|||||||
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
|
||||||
logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
|
logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
|
||||||
|
|
||||||
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||||
@@ -2250,10 +2258,9 @@ class TFGenerationMixin:
|
|||||||
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
# 3. init tensors to use for "xla-compileable" generate function
|
# 3. init tensors to use for "xla-compileable" generate function
|
||||||
# define bsz, seq_length
|
|
||||||
batch_size, cur_len = input_ids.shape
|
batch_size, cur_len = input_ids.shape
|
||||||
|
|
||||||
# initialize `generated`, `finished_sequences`
|
# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
|
||||||
generated = tf.TensorArray(
|
generated = tf.TensorArray(
|
||||||
element_shape=(batch_size,),
|
element_shape=(batch_size,),
|
||||||
dtype=tf.int32,
|
dtype=tf.int32,
|
||||||
@@ -2261,19 +2268,22 @@ class TFGenerationMixin:
|
|||||||
size=max_length,
|
size=max_length,
|
||||||
clear_after_read=False,
|
clear_after_read=False,
|
||||||
)
|
)
|
||||||
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
|
if pad_token_id: # ignores the cases when it is 0 or None
|
||||||
|
for i in range(max_length):
|
||||||
|
generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
|
||||||
|
|
||||||
# write prompt to generated
|
# write prompt to generated
|
||||||
for i in range(cur_len):
|
for i in range(cur_len):
|
||||||
generated = generated.write(i, input_ids[:, i])
|
generated = generated.write(i, input_ids[:, i])
|
||||||
|
|
||||||
|
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
|
||||||
|
|
||||||
# 4. define "xla-compile-able" stop-condition and auto-regressive function
|
# 4. define "xla-compile-able" stop-condition and auto-regressive function
|
||||||
def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||||
return ~tf.reduce_all(finished_sequences)
|
return ~tf.reduce_all(finished_sequences)
|
||||||
|
|
||||||
def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||||
# TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
|
model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
|
||||||
model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
|
|
||||||
# forward pass to get next token logits
|
# forward pass to get next token logits
|
||||||
outputs = self(
|
outputs = self(
|
||||||
**model_inputs,
|
**model_inputs,
|
||||||
@@ -2300,12 +2310,7 @@ class TFGenerationMixin:
|
|||||||
decoder_hidden_states.append(outputs.hidden_states)
|
decoder_hidden_states.append(outputs.hidden_states)
|
||||||
|
|
||||||
# pre-process distribution
|
# pre-process distribution
|
||||||
# TODO(pvp, joao, matt) - all the logits processors/wrappers need to be adapted
|
input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
|
||||||
# to be XLA compatible
|
|
||||||
input_ids = None
|
|
||||||
if not use_xla:
|
|
||||||
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
|
|
||||||
input_ids = tf.transpose(input_ids[:cur_len])
|
|
||||||
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
|
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
|
||||||
next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)
|
next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)
|
||||||
|
|
||||||
@@ -2359,7 +2364,7 @@ class TFGenerationMixin:
|
|||||||
# 2-to-n generation steps can then be run in autoregressive fashion
|
# 2-to-n generation steps can then be run in autoregressive fashion
|
||||||
# only in case 1st generation step does NOT yield EOS token though
|
# only in case 1st generation step does NOT yield EOS token though
|
||||||
if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
|
||||||
maximum_iterations = max_length - cur_len - 1
|
maximum_iterations = max_length - cur_len
|
||||||
generated, _, _, cur_len, _ = tf.while_loop(
|
generated, _, _, cur_len, _ = tf.while_loop(
|
||||||
sample_cond_fn,
|
sample_cond_fn,
|
||||||
sample_body_fn,
|
sample_body_fn,
|
||||||
@@ -2613,6 +2618,7 @@ class TFGenerationMixin:
|
|||||||
size=max_length,
|
size=max_length,
|
||||||
clear_after_read=False,
|
clear_after_read=False,
|
||||||
)
|
)
|
||||||
|
if pad_token_id: # ignores the cases when it is 0 or None
|
||||||
for i in range(max_length):
|
for i in range(max_length):
|
||||||
sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
|
sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
|
||||||
running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
|
running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
|
||||||
@@ -2699,9 +2705,7 @@ class TFGenerationMixin:
|
|||||||
(0, 0, cur_len - input_ids_length),
|
(0, 0, cur_len - input_ids_length),
|
||||||
(batch_size, num_beams, input_ids_length),
|
(batch_size, num_beams, input_ids_length),
|
||||||
)
|
)
|
||||||
model_inputs = self.prepare_inputs_for_generation(
|
model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_token), **model_kwargs)
|
||||||
flatten_beam_dim(input_token), use_xla=use_xla, **model_kwargs
|
|
||||||
)
|
|
||||||
model_outputs = self(
|
model_outputs = self(
|
||||||
**model_inputs,
|
**model_inputs,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
|
|||||||
@@ -490,8 +490,8 @@ class GenerationMixin:
|
|||||||
def _prepare_attention_mask_for_generation(
|
def _prepare_attention_mask_for_generation(
|
||||||
self,
|
self,
|
||||||
inputs: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
pad_token_id: int,
|
pad_token_id: Optional[int],
|
||||||
eos_token_id: int,
|
eos_token_id: Optional[int],
|
||||||
) -> torch.LongTensor:
|
) -> torch.LongTensor:
|
||||||
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
|
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
|
||||||
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
|
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
|
||||||
@@ -1137,7 +1137,11 @@ class GenerationMixin:
|
|||||||
eos_token_id = self.config.decoder.eos_token_id
|
eos_token_id = self.config.decoder.eos_token_id
|
||||||
|
|
||||||
if pad_token_id is None and eos_token_id is not None:
|
if pad_token_id is None and eos_token_id is not None:
|
||||||
# special case if pad_token_id is not defined
|
if model_kwargs.get("attention_mask", None) is None:
|
||||||
|
logger.warning(
|
||||||
|
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||||
|
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||||
|
)
|
||||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||||
pad_token_id = eos_token_id
|
pad_token_id = eos_token_id
|
||||||
|
|
||||||
|
|||||||
@@ -813,25 +813,21 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
def set_output_embeddings(self, value):
|
def set_output_embeddings(self, value):
|
||||||
self.set_input_embeddings(value)
|
self.set_input_embeddings(value)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
|
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
|
||||||
# TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
|
token_type_ids = kwargs.get("token_type_ids", None)
|
||||||
# tests will need to be fixed after the change
|
|
||||||
|
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
if past:
|
if past:
|
||||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)
|
||||||
|
|
||||||
# TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
|
position_ids = kwargs.get("position_ids", None)
|
||||||
# for a future PR to not change too many things for now.
|
|
||||||
# All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
|
|
||||||
position_ids = None
|
|
||||||
attention_mask = None
|
|
||||||
if use_xla:
|
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
if past is not None and attention_mask is not None:
|
|
||||||
position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
|
if attention_mask is not None and position_ids is None:
|
||||||
elif attention_mask is not None:
|
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
|
||||||
position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
|
if past:
|
||||||
|
position_ids = tf.expand_dims(position_ids[:, -1], -1)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": inputs,
|
"input_ids": inputs,
|
||||||
@@ -839,6 +835,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
"position_ids": position_ids,
|
"position_ids": position_ids,
|
||||||
"past": past,
|
"past": past,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
|
def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
|
||||||
|
|||||||
@@ -456,7 +456,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
sentences = ["Today is a beautiful day and", "Yesterday was"]
|
sentences = ["Today is a beautiful day and", "Yesterday was"]
|
||||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
|
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||||
|
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
|
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
|
||||||
@@ -465,12 +465,12 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
"repetition_penalty": 1.3,
|
"repetition_penalty": 1.3,
|
||||||
}
|
}
|
||||||
|
|
||||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
output_ids = model.generate(**input_ids, **generation_kwargs)
|
||||||
|
|
||||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
expected_output_string = [
|
expected_output_string = [
|
||||||
"Today is a beautiful day and I am so happy to be able take part in this amazing event.",
|
"Today is a beautiful day and I am so happy to be able take part in this amazing event.",
|
||||||
"Yesterday was a very busy day for the first time since I started writing this post",
|
"Yesterday was a very interesting time for the world to see how much of this is",
|
||||||
]
|
]
|
||||||
self.assertListEqual(output_strings, expected_output_string)
|
self.assertListEqual(output_strings, expected_output_string)
|
||||||
|
|
||||||
@@ -483,7 +483,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
sentences = ["Today is a beautiful day and", "Yesterday was"]
|
sentences = ["Today is a beautiful day and", "Yesterday was"]
|
||||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
|
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||||
|
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
"do_sample": True,
|
"do_sample": True,
|
||||||
@@ -498,13 +498,13 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
|
|
||||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||||
with tf.device(":/CPU:0"):
|
with tf.device(":/CPU:0"):
|
||||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
output_ids = model.generate(**input_ids, **generation_kwargs)
|
||||||
|
|
||||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
expected_output_string = [
|
expected_output_string = [
|
||||||
"Today is a beautiful day and we will make you feel very hot/terrific in all",
|
"Today is a beautiful day and we will make you feel very hot/terrific in all your",
|
||||||
"Yesterday was another solid success as news coverage became standard American domestic television hit.",
|
"Yesterday was known by national television networks as Le Big Show or Wild Dog Jeopard",
|
||||||
]
|
]
|
||||||
self.assertListEqual(output_strings, expected_output_string)
|
self.assertListEqual(output_strings, expected_output_string)
|
||||||
|
|
||||||
@@ -517,7 +517,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
sentences = ["Today is a beautiful day and", "Yesterday was"]
|
sentences = ["Today is a beautiful day and", "Yesterday was"]
|
||||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
|
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||||
|
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
|
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
|
||||||
@@ -526,37 +526,69 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
"num_beams": 2,
|
"num_beams": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
output_ids = model.generate(input_ids, **generation_kwargs)
|
output_ids = model.generate(**input_ids, **generation_kwargs)
|
||||||
|
|
||||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
expected_output_string = [
|
expected_output_string = [
|
||||||
"Today is a beautiful day and a great day for all of us.\n\nI’m",
|
"Today is a beautiful day and a great day for all of us.\n\nI’m",
|
||||||
"Yesterday was the first day of the year for the second time in a row,",
|
"Yesterday was the first time that a person has been arrested in the United States for",
|
||||||
]
|
]
|
||||||
self.assertListEqual(output_strings, expected_output_string)
|
self.assertListEqual(output_strings, expected_output_string)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_lm_generate_distilgpt2_left_padding(self):
|
||||||
|
"""Tests that the generated text is the same, regardless of left padding"""
|
||||||
|
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
|
||||||
|
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
|
generation_kwargs = {
|
||||||
|
"bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
|
||||||
|
"no_repeat_ngram_size": 2,
|
||||||
|
"do_sample": False,
|
||||||
|
"repetition_penalty": 1.3,
|
||||||
|
}
|
||||||
|
expected_output_string = (
|
||||||
|
"Today is a beautiful day and I am so happy to be able take part in this amazing event."
|
||||||
|
)
|
||||||
|
|
||||||
|
sentences = ["Today is a beautiful day and"]
|
||||||
|
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||||
|
# using default length
|
||||||
|
output_ids = model.generate(**input_ids, **generation_kwargs)
|
||||||
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
|
self.assertEqual(output_strings[0], expected_output_string)
|
||||||
|
|
||||||
|
sentences = ["Today is a beautiful day and", "This is a very long input that we absolutely don't care about"]
|
||||||
|
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||||
|
# longer max length to capture the full length (remember: it is left padded)
|
||||||
|
output_ids = model.generate(**input_ids, **generation_kwargs, max_length=27)
|
||||||
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
|
self.assertEqual(output_strings[0], expected_output_string)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_lm_generate_gpt2_greedy_xla(self):
|
def test_lm_generate_gpt2_greedy_xla(self):
|
||||||
# TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
|
|
||||||
# the underlying problem)
|
|
||||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||||
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
sentences = ["The dog"]
|
sentences = ["The dog", "The flying machine"]
|
||||||
expected_output_strings = [
|
expected_output_strings = [
|
||||||
"The dog was found in a field near the intersection of West and West Streets.\n\nThe dog",
|
"The dog was found in a field near the intersection of West and West Streets.\n\nThe",
|
||||||
|
"The flying machine is a small, lightweight, and lightweight aircraft that can be used for any type of",
|
||||||
]
|
]
|
||||||
input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
|
input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
|
||||||
|
|
||||||
output_ids = model.generate(input_ids, do_sample=False)
|
output_ids = model.generate(**input_ids, do_sample=False)
|
||||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
self.assertListEqual(output_strings, expected_output_strings)
|
self.assertListEqual(output_strings, expected_output_strings)
|
||||||
|
|
||||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||||
output_ids = xla_generate(input_ids, do_sample=False)
|
output_ids = xla_generate(**input_ids, do_sample=False)
|
||||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
self.assertListEqual(output_strings, expected_output_strings)
|
self.assertListEqual(output_strings, expected_output_strings)
|
||||||
|
|
||||||
@@ -574,21 +606,24 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
sentence = ["The dog"]
|
sentence = ["The dog", "The flying machine"]
|
||||||
expected_output_string = [
|
expected_output_string = [
|
||||||
"The dog owner asked why did our vet decide there needed to be extra ventilation inside because most"
|
"The dog owner asked why did our vet decide there needed to be extra ventilation inside because most"
|
||||||
" puppies"
|
" puppies",
|
||||||
|
"The flying machine was made by an artist who found it difficult to control it as it did not use",
|
||||||
]
|
]
|
||||||
expected_output_string_xla = [
|
expected_output_string_xla = [
|
||||||
"The dog has been named in connection with the murder of a 20-year-old man in!"
|
"The dog has been named in connection with the murder of a 20-year-old man in",
|
||||||
|
"The flying machine is a new and improved system to operate and operate a new system and system "
|
||||||
|
"system system",
|
||||||
]
|
]
|
||||||
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
|
input_ids = tokenizer(sentence, return_tensors="tf", padding=True)
|
||||||
|
|
||||||
output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
|
output_ids = model.generate(**input_ids, do_sample=True, seed=[7, 0])
|
||||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
self.assertListEqual(output_strings, expected_output_string)
|
self.assertListEqual(output_strings, expected_output_string)
|
||||||
|
|
||||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||||
output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
|
output_ids = xla_generate(**input_ids, do_sample=True, seed=[7, 0])
|
||||||
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
||||||
self.assertListEqual(output_strings, expected_output_string_xla)
|
self.assertListEqual(output_strings, expected_output_string_xla)
|
||||||
|
|||||||
Reference in New Issue
Block a user