python 2 compatibility

This commit is contained in:
thomwolf
2019-02-06 00:07:46 +01:00
parent ba37ddc5ce
commit 448937c00d
17 changed files with 246 additions and 184 deletions

View File

@@ -24,6 +24,8 @@ import os
import shutil
import tarfile
import tempfile
import sys
from io import open
import torch
import torch.nn as nn
@@ -160,7 +162,8 @@ class OpenAIGPTConfig(object):
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str):
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
@@ -442,7 +445,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
@@ -641,7 +644,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
hidden_states = inputs_embeds + position_embeds + token_type_embeds
for block in self.h:
hidden_states = block(hidden_states)
return hidden_states.view(*input_shape, hidden_states.size(-1))
output_shape = input_shape + (hidden_states.size(-1),)
return hidden_states.view(*output_shape)
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):