python 2 compatibility

This commit is contained in:
thomwolf
2019-02-06 00:07:46 +01:00
parent ba37ddc5ce
commit 448937c00d
17 changed files with 246 additions and 184 deletions

View File

@@ -15,18 +15,18 @@
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import copy
import json
import math
import logging
import math
import os
import shutil
import tarfile
import tempfile
import shutil
import sys
from io import open
import torch
from torch import nn
@@ -56,7 +56,7 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
import re
import numpy as np
import tensorflow as tf
except ModuleNotFoundError:
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
@@ -164,7 +164,8 @@ class BertConfig(object):
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str):
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
@@ -343,8 +344,10 @@ class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
@@ -416,8 +419,10 @@ class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
@@ -542,7 +547,7 @@ class BertPreTrainedModel(nn.Module):
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "