Fix PretrainedModel.from_pretrained not passing cache_dir forward

This commit is contained in:
Anish Moorthy
2019-07-22 15:51:51 -04:00
parent b8009cb0da
commit 490ebbdcf7

View File

@@ -373,7 +373,8 @@ class PreTrainedModel(nn.Module):
if config is None: if config is None:
config, model_kwargs = cls.config_class.from_pretrained( config, model_kwargs = cls.config_class.from_pretrained(
pretrained_model_name_or_path, *model_args, pretrained_model_name_or_path, *model_args,
return_unused_args=True, **kwargs cache_dir=cache_dir, return_unused_args=True,
**kwargs
) )
else: else:
model_kwargs = kwargs model_kwargs = kwargs