From a5997dd81a76d669e81a42f8efafcfd1745704b9 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Thu, 10 Oct 2019 11:31:01 +0200
Subject: [PATCH] better error messages

---
Notes (not part of the commit message; text between the "---" line and the
diffstat is ignored by `git am`).

What this patch does:

* examples/run_generation.py
  - sample_sequence() gains two keyword arguments, `is_xlm_mlm` (default
    False) and `xlm_mask_token` (default None). When both are set, one mask
    token is appended to `generated` and the result is used as `input_ids`
    for the forward pass.
  - main() derives both values from --model_type / --model_name_or_path and
    the tokenizer's `mask_token_id`, and passes them to sample_sequence().
  - The CTRL temperature hint moves below print(args) and is emitted with
    logger.info() instead of print().
* transformers/configuration_utils.py, transformers/modeling_utils.py
  - When cached_path() raises EnvironmentError, the explanatory text is no
    longer logged before re-raising the original exception; a new
    EnvironmentError carrying that text is raised instead.
* transformers/tokenization_utils.py
  - Same change for the vocabulary download path.
  - When every entry of `vocab_files` is None, the code now raises
    EnvironmentError where it previously logged an error and returned None.

Review changes relative to the originally posted version (each replaces one
added line in place, so hunk sizes and the diffstat below are unchanged; the
blob ids on the "index" lines still describe the original post-image):

* tokenization_utils.py: "Couldn't reach server at '{}' ..." was assigned
  without .format(), so the raised message contained a literal '{}'. It is
  now formatted with `file_path`, the argument of the cached_path() call in
  the same try block.
* run_generation.py: the MLM guard tests `xlm_mask_token is not None`
  instead of truthiness, so a mask token id of 0 is not silently ignored.

NOTE(review): line breaks and indentation in this file were reconstructed
from a whitespace-collapsed copy. Hunk line counts were checked against the
hunk headers and the diffstat, but exact indentation, comment spacing and the
split of the '--stop_token' argument across two context lines are
reconstructed -- verify context lines against the tree before applying.
NOTE(review): the appended mask tensor has shape (1, 1) while `generated` is
`context` repeated `num_samples` times along dim 0; presumably this path only
works for num_samples == 1 -- confirm.
NOTE(review): callers that relied on the tokenizer loader returning None when
no vocabulary file is found will now see an EnvironmentError -- confirm this
is intended.

 examples/run_generation.py          | 29 +++++++++++++++++++++++------
 transformers/configuration_utils.py | 19 +++++++++----------
 transformers/modeling_utils.py      | 20 ++++++++++----------
 transformers/tokenization_utils.py  | 29 +++++++++++++++--------------
 4 files changed, 57 insertions(+), 40 deletions(-)

diff --git a/examples/run_generation.py b/examples/run_generation.py
index 5ff05f66b2..f62c3848fc 100644
--- a/examples/run_generation.py
+++ b/examples/run_generation.py
@@ -107,7 +107,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
     return logits
 
 
-def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, is_xlnet=False, xlm_lang=None, device='cpu'):
+def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
+                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
     context = torch.tensor(context, dtype=torch.long, device=device)
     context = context.unsqueeze(0).repeat(num_samples, 1)
     generated = context
@@ -125,10 +126,16 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
                 target_mapping[0, 0, -1] = 1.0  # predict last token
                 inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
 
+            if is_xlm_mlm and xlm_mask_token is not None:
+                # XLM MLM models are direct models (predict same token, not next token)
+                # => need one additional dummy token in the input (will be masked and guessed)
+                input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
+                inputs = {'input_ids': input_ids}
+
             if xlm_lang is not None:
                 inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
 
-            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
+            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
             next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)
 
             # reptition penalty from CTRL (https://arxiv.org/abs/1909.05858)
@@ -167,10 +174,7 @@ def main():
     parser.add_argument('--stop_token', type=str, default=None,
                         help="Token at which text generation is stopped")
     args = parser.parse_args()
-    if args.model_type in ["ctrl"]:
-        if args.temperature > 0.7 :
-            print('CTRL typically works better with lower temperatures (and lower top_k).')
-
+
     args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
     args.n_gpu = torch.cuda.device_count()
 
@@ -191,6 +195,10 @@ def main():
         args.length = MAX_LENGTH  # avoid infinite loop
 
     print(args)
+    if args.model_type in ["ctrl"]:
+        if args.temperature > 0.7 :
+            logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
+
     while True:
         xlm_lang = None
         # XLM Language usage detailed in the issues #1414
@@ -204,6 +212,13 @@ def main():
                     language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
             xlm_lang = tokenizer.lang2id[language]
 
+        # XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
+        is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
+        if is_xlm_mlm:
+            xlm_mask_token = tokenizer.mask_token_id
+        else:
+            xlm_mask_token = None
+
         raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
         if args.model_type in ["transfo-xl", "xlnet"]:
             # Models with memory likes to have a long prompt for short inputs.
@@ -218,6 +233,8 @@ def main():
             top_p=args.top_p,
             repetition_penalty=args.repetition_penalty,
             is_xlnet=bool(args.model_type == "xlnet"),
+            is_xlm_mlm=is_xlm_mlm,
+            xlm_mask_token=xlm_mask_token,
             xlm_lang=xlm_lang,
             device=args.device,
         )
diff --git a/transformers/configuration_utils.py b/transformers/configuration_utils.py
index 8a23be4ff6..112b15190f 100644
--- a/transformers/configuration_utils.py
+++ b/transformers/configuration_utils.py
@@ -130,20 +130,19 @@ class PretrainedConfig(object):
         # redirect to the cache, if necessary
         try:
             resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
-        except EnvironmentError as e:
+        except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
-                        config_file))
+                msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
+                        config_file)
             else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find any file "
-                    "associated to this path or url.".format(
+                msg = "Model name '{}' was not found in model name list ({}). " \
+                      "We assumed '{}' was a path or url to a configuration file named {} or " \
+                      "a directory containing such a file but couldn't find any such file at this path or url.".format(
                         pretrained_model_name_or_path,
                         ', '.join(cls.pretrained_config_archive_map.keys()),
-                        config_file))
-            raise e
+                        config_file, CONFIG_NAME)
+            raise EnvironmentError(msg)
+
         if resolved_config_file == config_file:
             logger.info("loading configuration file {}".format(config_file))
         else:
diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py
index 84b64e3ca4..d082137d5d 100644
--- a/transformers/modeling_utils.py
+++ b/transformers/modeling_utils.py
@@ -316,20 +316,20 @@ class PreTrainedModel(nn.Module):
             # redirect to the cache, if necessary
             try:
                 resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
-            except EnvironmentError as e:
+            except EnvironmentError:
                 if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-                    logger.error(
-                        "Couldn't reach server at '{}' to download pretrained weights.".format(
-                            archive_file))
+                    msg = "Couldn't reach server at '{}' to download pretrained weights.".format(
+                            archive_file)
                 else:
-                    logger.error(
-                        "Model name '{}' was not found in model name list ({}). "
-                        "We assumed '{}' was a path or url but couldn't find any file "
-                        "associated to this path or url.".format(
+                    msg = "Model name '{}' was not found in model name list ({}). " \
+                        "We assumed '{}' was a path or url to model weight files named one of {} but " \
+                        "couldn't find any such file at this path or url.".format(
                             pretrained_model_name_or_path,
                             ', '.join(cls.pretrained_model_archive_map.keys()),
-                            archive_file))
-                raise e
+                            archive_file,
+                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME])
+                raise EnvironmentError(msg)
+
             if resolved_archive_file == archive_file:
                 logger.info("loading weights file {}".format(archive_file))
             else:
diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py
index 313547a533..5e5be872ef 100644
--- a/transformers/tokenization_utils.py
+++ b/transformers/tokenization_utils.py
@@ -337,13 +337,13 @@ class PreTrainedTokenizer(object):
                 vocab_files[file_id] = full_file_name
 
             if all(full_file_name is None for full_file_name in vocab_files.values()):
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find tokenizer files"
-                    "at this path or url.".format(
+                raise EnvironmentError(
+                    "Model name '{}' was not found in tokenizers model name list ({}). "
+                    "We assumed '{}' was a path or url to a directory containing vocabulary files "
+                    "named {} but couldn't find such vocabulary files at this path or url.".format(
                         pretrained_model_name_or_path, ', '.join(s3_models),
-                        pretrained_model_name_or_path, ))
-                return None
+                        pretrained_model_name_or_path,
+                        list(cls.vocab_files_names.values())))
 
         # Get files from url, cache, or disk depending on the case
         try:
@@ -353,17 +353,18 @@ class PreTrainedTokenizer(object):
                     resolved_vocab_files[file_id] = None
                 else:
                     resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
-        except EnvironmentError as e:
+        except EnvironmentError:
             if pretrained_model_name_or_path in s3_models:
-                logger.error("Couldn't reach server to download vocabulary.")
+                msg = "Couldn't reach server at '{}' to download vocabulary files.".format(file_path)
             else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} "
-                    "at this path or url.".format(
+                msg = "Model name '{}' was not found in tokenizers model name list ({}). " \
+                    "We assumed '{}' was a path or url to a directory containing vocabulary files " \
+                    "named {}, but couldn't find such vocabulary files at this path or url.".format(
                         pretrained_model_name_or_path, ', '.join(s3_models),
-                        pretrained_model_name_or_path, str(vocab_files.keys())))
-            raise e
+                        pretrained_model_name_or_path,
+                        list(cls.vocab_files_names.values()))
+
+            raise EnvironmentError(msg)
 
         for file_id, file_path in vocab_files.items():
             if file_path == resolved_vocab_files[file_id]: