python 2 compatibility
This commit is contained in:
@@ -24,6 +24,8 @@ import os
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -160,7 +162,8 @@ class OpenAIGPTConfig(object):
|
||||
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
if isinstance(vocab_size_or_config_json_file, str):
|
||||
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
@@ -442,7 +445,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
except EnvironmentError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
@@ -641,7 +644,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||
for block in self.h:
|
||||
hidden_states = block(hidden_states)
|
||||
return hidden_states.view(*input_shape, hidden_states.size(-1))
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
return hidden_states.view(*output_shape)
|
||||
|
||||
|
||||
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
Reference in New Issue
Block a user