From 9d94aecd516c7540a94b9d781ef28d7375a796bc Mon Sep 17 00:00:00 2001 From: Zhu Baohe Date: Thu, 13 Aug 2020 19:12:16 +0800 Subject: [PATCH] Fix docs and bad word tokens generation_utils.py (#6387) * fix * fix2 * fix3 --- src/transformers/generation_tf_utils.py | 6 +++--- src/transformers/generation_utils.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index f84d420284..41ef2f51e0 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -163,7 +163,7 @@ class TFGenerationMixin: model = TFAutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. input_context = 'The dog' input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context - outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3) # 3 generate sequences using by sampling + outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling for i in range(3): # 3 output sequences were generated print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) @@ -936,8 +936,8 @@ def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids): if len(tokens) == 0: # if bad word tokens is just one token always ban it return True - if len(tokens) > len(prev_input_ids): - # if bad word tokens are longer then prev input_ids they can't be equal + if len(tokens) > len(prev_tokens): + # if bad word tokens are longer than prev tokens they can't be equal return False if prev_tokens[-len(tokens) :] == tokens: diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index c731c01fa9..6a0d7ad020 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -226,7 +226,7 @@ class GenerationMixin: model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. input_context = 'The dog' input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context - outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3) # 3 generate sequences using by sampling + outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling for i in range(3): # 3 output sequences were generated print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) @@ -876,8 +876,8 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter if len(tokens) == 0: # if bad word tokens is just one token always ban it return True - if len(tokens) > len(prev_input_ids): - # if bad word tokens are longer then prev input_ids they can't be equal + if len(tokens) > len(prev_tokens): + # if bad word tokens are longer than prev tokens they can't be equal return False if prev_tokens[-len(tokens) :] == tokens: