Merge branch 'master' into conditional-generation

This commit is contained in:
Thomas Wolf
2019-10-30 16:40:35 +01:00
committed by GitHub
87 changed files with 5059 additions and 719 deletions

View File

@@ -53,7 +53,8 @@ class PretrainedConfig(object):
self.num_labels = kwargs.pop('num_labels', 2)
self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.torchscript = kwargs.pop('torchscript', False)
self.output_past = kwargs.pop('output_past', True) # Not used by all models
self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
self.pruned_heads = kwargs.pop('pruned_heads', {})
self.is_decoder = kwargs.pop('is_decoder', False)
@@ -131,20 +132,19 @@ class PretrainedConfig(object):
# redirect to the cache, if necessary
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError as e:
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file)
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
msg = "Model name '{}' was not found in model name list ({}). " \
"We assumed '{}' was a path or url to a configuration file named {} or " \
"a directory containing such a file but couldn't find any such file at this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_config_archive_map.keys()),
config_file))
raise e
config_file, CONFIG_NAME)
raise EnvironmentError(msg)
if resolved_config_file == config_file:
logger.info("loading configuration file {}".format(config_file))
else:
@@ -155,7 +155,7 @@ class PretrainedConfig(object):
config = cls.from_json_file(resolved_config_file)
if hasattr(config, 'pruned_heads'):
config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
# Update config with kwargs if needed
to_remove = []
@@ -166,7 +166,7 @@ class PretrainedConfig(object):
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config %s", config)
logger.info("Model config %s", str(config))
if return_unused_kwargs:
return config, kwargs
else: