Generate: generation config validation fixes in docs (#25405)

This commit is contained in:
Joao Gante
2023-08-09 13:07:11 +01:00
committed by GitHub
parent 00b93cda21
commit f456b4d10b
2 changed files with 1 additions and 7 deletions

View File

@@ -80,11 +80,9 @@ into a single instance to both extract the input features and decode the predict
... pixel_values.to(device), ... pixel_values.to(device),
... decoder_input_ids=decoder_input_ids.to(device), ... decoder_input_ids=decoder_input_ids.to(device),
... max_length=model.decoder.config.max_position_embeddings, ... max_length=model.decoder.config.max_position_embeddings,
... early_stopping=True,
... pad_token_id=processor.tokenizer.pad_token_id, ... pad_token_id=processor.tokenizer.pad_token_id,
... eos_token_id=processor.tokenizer.eos_token_id, ... eos_token_id=processor.tokenizer.eos_token_id,
... use_cache=True, ... use_cache=True,
... num_beams=1,
... bad_words_ids=[[processor.tokenizer.unk_token_id]], ... bad_words_ids=[[processor.tokenizer.unk_token_id]],
... return_dict_in_generate=True, ... return_dict_in_generate=True,
... ) ... )
@@ -125,11 +123,9 @@ into a single instance to both extract the input features and decode the predict
... pixel_values.to(device), ... pixel_values.to(device),
... decoder_input_ids=decoder_input_ids.to(device), ... decoder_input_ids=decoder_input_ids.to(device),
... max_length=model.decoder.config.max_position_embeddings, ... max_length=model.decoder.config.max_position_embeddings,
... early_stopping=True,
... pad_token_id=processor.tokenizer.pad_token_id, ... pad_token_id=processor.tokenizer.pad_token_id,
... eos_token_id=processor.tokenizer.eos_token_id, ... eos_token_id=processor.tokenizer.eos_token_id,
... use_cache=True, ... use_cache=True,
... num_beams=1,
... bad_words_ids=[[processor.tokenizer.unk_token_id]], ... bad_words_ids=[[processor.tokenizer.unk_token_id]],
... return_dict_in_generate=True, ... return_dict_in_generate=True,
... ) ... )
@@ -172,11 +168,9 @@ into a single instance to both extract the input features and decode the predict
... pixel_values.to(device), ... pixel_values.to(device),
... decoder_input_ids=decoder_input_ids.to(device), ... decoder_input_ids=decoder_input_ids.to(device),
... max_length=model.decoder.config.max_position_embeddings, ... max_length=model.decoder.config.max_position_embeddings,
... early_stopping=True,
... pad_token_id=processor.tokenizer.pad_token_id, ... pad_token_id=processor.tokenizer.pad_token_id,
... eos_token_id=processor.tokenizer.eos_token_id, ... eos_token_id=processor.tokenizer.eos_token_id,
... use_cache=True, ... use_cache=True,
... num_beams=1,
... bad_words_ids=[[processor.tokenizer.unk_token_id]], ... bad_words_ids=[[processor.tokenizer.unk_token_id]],
... return_dict_in_generate=True, ... return_dict_in_generate=True,
... ) ... )

View File

@@ -597,7 +597,7 @@ class GenerationConfig(PushToHubMixin):
>>> # If you'd like to try a minor variation to an existing configuration, you can also pass generation >>> # If you'd like to try a minor variation to an existing configuration, you can also pass generation
>>> # arguments to `.from_pretrained()`. Be mindful that typos and unused arguments will be ignored >>> # arguments to `.from_pretrained()`. Be mindful that typos and unused arguments will be ignored
>>> generation_config, unused_kwargs = GenerationConfig.from_pretrained( >>> generation_config, unused_kwargs = GenerationConfig.from_pretrained(
... "gpt2", top_k=1, foo=False, return_unused_kwargs=True ... "gpt2", top_k=1, foo=False, do_sample=True, return_unused_kwargs=True
... ) ... )
>>> generation_config.top_k >>> generation_config.top_k
1 1