better error messages
This commit is contained in:
@@ -107,7 +107,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
|
||||
return logits
|
||||
|
||||
|
||||
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, is_xlnet=False, xlm_lang=None, device='cpu'):
|
||||
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
|
||||
is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
|
||||
context = torch.tensor(context, dtype=torch.long, device=device)
|
||||
context = context.unsqueeze(0).repeat(num_samples, 1)
|
||||
generated = context
|
||||
@@ -125,10 +126,16 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
|
||||
target_mapping[0, 0, -1] = 1.0 # predict last token
|
||||
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
||||
|
||||
if is_xlm_mlm and xlm_mask_token:
|
||||
# XLM MLM models are direct models (predict same token, not next token)
|
||||
# => need one additional dummy token in the input (will be masked and guessed)
|
||||
input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
|
||||
inputs = {'input_ids': input_ids}
|
||||
|
||||
if xlm_lang is not None:
|
||||
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
|
||||
|
||||
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
||||
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
|
||||
next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)
|
||||
|
||||
# repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
|
||||
@@ -167,10 +174,7 @@ def main():
|
||||
parser.add_argument('--stop_token', type=str, default=None,
|
||||
help="Token at which text generation is stopped")
|
||||
args = parser.parse_args()
|
||||
if args.model_type in ["ctrl"]:
|
||||
if args.temperature > 0.7 :
|
||||
print('CTRL typically works better with lower temperatures (and lower top_k).')
|
||||
|
||||
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
|
||||
@@ -191,6 +195,10 @@ def main():
|
||||
args.length = MAX_LENGTH # avoid infinite loop
|
||||
|
||||
print(args)
|
||||
if args.model_type in ["ctrl"]:
|
||||
if args.temperature > 0.7 :
|
||||
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
|
||||
|
||||
while True:
|
||||
xlm_lang = None
|
||||
# XLM language usage is detailed in issue #1414
|
||||
@@ -204,6 +212,13 @@ def main():
|
||||
language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
|
||||
xlm_lang = tokenizer.lang2id[language]
|
||||
|
||||
# XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
|
||||
is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
|
||||
if is_xlm_mlm:
|
||||
xlm_mask_token = tokenizer.mask_token_id
|
||||
else:
|
||||
xlm_mask_token = None
|
||||
|
||||
raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
|
||||
if args.model_type in ["transfo-xl", "xlnet"]:
|
||||
# Models with memory like to have a long prompt for short inputs.
|
||||
@@ -218,6 +233,8 @@ def main():
|
||||
top_p=args.top_p,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
is_xlnet=bool(args.model_type == "xlnet"),
|
||||
is_xlm_mlm=is_xlm_mlm,
|
||||
xlm_mask_token=xlm_mask_token,
|
||||
xlm_lang=xlm_lang,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user