Fix docs and bad word tokens generation_utils.py (#6387)
* fix * fix2 * fix3
This commit is contained in:
@@ -163,7 +163,7 @@ class TFGenerationMixin:
|
|||||||
model = TFAutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
model = TFAutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
||||||
input_context = 'The dog'
|
input_context = 'The dog'
|
||||||
input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context
|
input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context
|
||||||
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3) # 3 generate sequences using by sampling
|
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling
|
||||||
for i in range(3): # 3 output sequences were generated
|
for i in range(3): # 3 output sequences were generated
|
||||||
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
|
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
|
||||||
|
|
||||||
@@ -936,8 +936,8 @@ def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
|
|||||||
if len(tokens) == 0:
|
if len(tokens) == 0:
|
||||||
# if bad word tokens is just one token always ban it
|
# if bad word tokens is just one token always ban it
|
||||||
return True
|
return True
|
||||||
if len(tokens) > len(prev_input_ids):
|
if len(tokens) > len(prev_tokens):
|
||||||
# if bad word tokens are longer then prev input_ids they can't be equal
|
# if bad word tokens are longer than prev tokens they can't be equal
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if prev_tokens[-len(tokens) :] == tokens:
|
if prev_tokens[-len(tokens) :] == tokens:
|
||||||
|
|||||||
@@ -226,7 +226,7 @@ class GenerationMixin:
|
|||||||
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
||||||
input_context = 'The dog'
|
input_context = 'The dog'
|
||||||
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
|
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
|
||||||
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3) # 3 generate sequences using by sampling
|
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling
|
||||||
for i in range(3): # 3 output sequences were generated
|
for i in range(3): # 3 output sequences were generated
|
||||||
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
|
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
|
||||||
|
|
||||||
@@ -876,8 +876,8 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
|
|||||||
if len(tokens) == 0:
|
if len(tokens) == 0:
|
||||||
# if bad word tokens is just one token always ban it
|
# if bad word tokens is just one token always ban it
|
||||||
return True
|
return True
|
||||||
if len(tokens) > len(prev_input_ids):
|
if len(tokens) > len(prev_tokens):
|
||||||
# if bad word tokens are longer then prev input_ids they can't be equal
|
# if bad word tokens are longer than prev tokens they can't be equal
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if prev_tokens[-len(tokens) :] == tokens:
|
if prev_tokens[-len(tokens) :] == tokens:
|
||||||
|
|||||||
Reference in New Issue
Block a user