adding Transformer XL
This commit is contained in:
125
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
Executable file
125
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
Executable file
@@ -0,0 +1,125 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The HugginFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Convert OpenAI GPT checkpoint."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import argparse
|
||||||
|
import tensorflow as tf
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
|
||||||
|
|
||||||
|
|
||||||
|
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
                                             transfo_xl_config_file,
                                             pytorch_dump_folder_path):
    """Convert a TensorFlow Transformer XL checkpoint into PyTorch files.

    Every variable of the TF checkpoint is loaded, mapped onto the matching
    attribute of a freshly built ``TransfoXLModel`` by walking the
    '/'-separated variable name, and the resulting ``state_dict`` plus the
    configuration JSON are written into ``pytorch_dump_folder_path``.

    Args:
        tf_checkpoint_path: path of the TensorFlow checkpoint to read.
        transfo_xl_config_file: configuration file handed to
            ``TransfoXLConfig``; an empty string selects the default
            configuration.
        pytorch_dump_folder_path: directory receiving ``WEIGHTS_NAME`` and
            ``CONFIG_NAME``.

    Raises:
        AssertionError: if a TF array and its PyTorch target differ in shape;
            the exception args are ``(pytorch_shape, tf_shape)``.
    """
    config_path = os.path.abspath(transfo_xl_config_file)
    tf_path = os.path.abspath(tf_checkpoint_path)

    print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
    # Load weights from TF model
    names = []
    arrays = []
    for name, shape in tf.train.list_variables(tf_path):
        print("Loading TF weight {} with shape {}".format(name, shape))
        names.append(name)
        arrays.append(tf.train.load_variable(tf_path, name))

    # Initialise PyTorch model
    if transfo_xl_config_file == "":
        config = TransfoXLConfig()
    else:
        config = TransfoXLConfig(transfo_xl_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = TransfoXLModel(config)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are optimizer slot variables (first/second moment
        # estimates); they are not needed to use the pretrained model.
        if any(n in ["adam_v", "adam_m"] for n in name):
            print("Skipping {}".format("/".join(name)))
            continue

        # Walk the model attribute by attribute, following the TF scope names.
        pointer = model
        for m_name in name:
            # "<word>_<digits>" scopes address an element of an indexable
            # container: split into attribute name and index.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] == 'kernel' or l[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'output_bias' or l[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, l[0])
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]

        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            # TF kernels are transposed before being copied into PyTorch weights.
            array = np.transpose(array)

        # Explicit check instead of a bare ``assert`` so that it is not
        # stripped when Python runs with -O; exception type and args are the
        # same as before.
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)

    # Save pytorch-model
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
## Required parameters
|
||||||
|
parser.add_argument("--tf_checkpoint_path",
|
||||||
|
default = None,
|
||||||
|
type = str,
|
||||||
|
required = True,
|
||||||
|
help = "Path the TensorFlow checkpoint path.")
|
||||||
|
parser.add_argument("--transfo_xl_config_file",
|
||||||
|
default = None,
|
||||||
|
type = str,
|
||||||
|
required = True,
|
||||||
|
help = "The config json file corresponding to the pre-trained BERT model. \n"
|
||||||
|
"This specifies the model architecture.")
|
||||||
|
parser.add_argument("--pytorch_dump_folder_path",
|
||||||
|
default = None,
|
||||||
|
type = str,
|
||||||
|
required = True,
|
||||||
|
help = "Path to the output PyTorch model.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||||
|
args.transfo_xl_config_file,
|
||||||
|
args.pytorch_dump_folder_path)
|
||||||
@@ -1,3 +1,20 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The OpenAI Team Authors and HugginFace Inc. team.
|
||||||
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch OpenAI GPT model."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
|
|||||||
1432
pytorch_pretrained_bert/modeling_transfo_xl.py
Normal file
1432
pytorch_pretrained_bert/modeling_transfo_xl.py
Normal file
File diff suppressed because it is too large
Load Diff
314
pytorch_pretrained_bert/modeling_transfo_xl_utilities.py
Normal file
314
pytorch_pretrained_bert/modeling_transfo_xl_utilities.py
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HugginFace Inc. team.
|
||||||
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Utilities for PyTorch Transformer XL model.
|
||||||
|
Directly adapted from https://github.com/kimiyoung/transformer-xl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
# CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
|
||||||
|
# CUDA_MINOR = int(torch.version.cuda.split('.')[1])
|
||||||
|
|
||||||
|
class ProjectedAdaptiveLogSoftmax(nn.Module):
    """Adaptive log-softmax output layer with optional per-cluster projections.

    The vocabulary ``[0, n_token)`` is partitioned by ``cutoffs`` into a
    shortlist (the first range) and ``n_clusters`` tail clusters. The head
    softmax scores the shortlist tokens plus one entry per tail cluster; each
    tail cluster has its own softmax over the tokens it contains.

    Args:
        n_token: vocabulary size.
        d_embed: embedding dimension of the first (shortlist) output layer.
        d_proj: dimension of the incoming hidden states.
        cutoffs: increasing sequence of cluster boundaries (``n_token`` is
            appended automatically).
        div_val: when different from 1, cluster ``i`` uses an embedding
            dimension of ``d_embed // div_val ** i`` and its own output layer.
        keep_order: default for ``forward``'s ``keep_order`` behaviour.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False):
        super(ProjectedAdaptiveLogSoftmax, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        # list(...) so that tuples (or other sequences) are accepted as well.
        self.cutoffs = list(cutoffs) + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        if self.n_clusters > 0:
            # One extra head row (and bias entry) per tail cluster.
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ParameterList()

        # NOTE: projections are created with torch.Tensor(...), i.e. their
        # memory is not initialised here; they presumably get filled by the
        # owning model's weight initialisation or by loading a checkpoint.
        if div_val == 1:
            # A single output layer is shared by all clusters and sliced per
            # cluster in forward(); one projection slot per cluster.
            for _ in range(len(self.cutoffs)):
                if d_proj != d_embed:
                    self.out_projs.append(
                        nn.Parameter(torch.Tensor(d_proj, d_embed))
                    )
                else:
                    self.out_projs.append(None)

            self.out_layers.append(nn.Linear(d_embed, n_token))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val ** i)

                self.out_projs.append(
                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
                )

                self.out_layers.append(nn.Linear(d_emb_i, r_idx - l_idx))

        self.keep_order = keep_order

    def _compute_logit(self, hidden, weight, bias, proj):
        """Return ``hidden @ proj @ weight.T + bias`` (``proj`` skipped if None)."""
        if proj is None:
            logit = F.linear(hidden, weight, bias=bias)
        else:
            proj_hid = F.linear(hidden, proj.t().contiguous())
            logit = F.linear(proj_hid, weight, bias=bias)

        return logit

    def forward(self, hidden, target, keep_order=False):
        '''
            hidden :: [len*bsz x d_proj]
            target :: [len*bsz]

        Returns the negative log-likelihood of every target, a tensor with the
        shape of ``target`` and the dtype of ``hidden``. When ``keep_order``
        (argument or attribute) is true the entries are in input order;
        otherwise they are grouped cluster by cluster.

        Raises RuntimeError if ``hidden`` and ``target`` differ in their
        first dimension.
        '''

        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        if self.n_clusters == 0:
            logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                        self.out_layers[0].bias, self.out_projs[0])
            nll = -F.log_softmax(logit, dim=-1) \
                .gather(1, target.unsqueeze(1)).squeeze(1)
            return nll

        # construct weights and biases
        weights, biases = [], []
        for i in range(len(self.cutoffs)):
            if self.div_val == 1:
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                weight_i = self.out_layers[0].weight[l_idx:r_idx]
                bias_i = self.out_layers[0].bias[l_idx:r_idx]
            else:
                weight_i = self.out_layers[i].weight
                bias_i = self.out_layers[i].bias

            if i == 0:
                # The head also scores one entry per tail cluster.
                weight_i = torch.cat(
                    [weight_i, self.cluster_weight], dim=0)
                bias_i = torch.cat(
                    [bias_i, self.cluster_bias], dim=0)

            weights.append(weight_i)
            biases.append(bias_i)

        head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]

        head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
        head_logprob = F.log_softmax(head_logit, dim=1)

        nll = torch.zeros_like(target,
                               dtype=hidden.dtype, device=hidden.device)

        # getattr keeps working for instances created before the attribute existed.
        preserve_order = getattr(self, 'keep_order', False) or keep_order

        offset = 0
        for i in range(len(self.cutoff_ends) - 1):
            l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

            mask_i = (target >= l_idx) & (target < r_idx)
            # squeeze(1) rather than squeeze(): with exactly one match a bare
            # squeeze() would collapse the index to a 0-dim tensor.
            indices_i = mask_i.nonzero().squeeze(1)

            if indices_i.numel() == 0:
                continue

            target_i = target.index_select(0, indices_i) - l_idx
            head_logprob_i = head_logprob.index_select(0, indices_i)

            if i == 0:
                logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
            else:
                weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                hidden_i = hidden.index_select(0, indices_i)

                tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                # Cluster i's log-probability is read from head column -i
                # (counted from the end of the head output).
                logprob_i = head_logprob_i[:, -i] \
                    + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)

            if preserve_order:
                nll.index_copy_(0, indices_i, -logprob_i)
            else:
                nll[offset:offset + logprob_i.size(0)].copy_(-logprob_i)

            offset += logprob_i.size(0)

        return nll
|
||||||
|
|
||||||
|
class LogUniformSampler(object):
    """Negative sampler drawing class ids from a log-uniform distribution."""

    def __init__(self, range_max, n_sample):
        """
        Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
            `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`

        expected count can be approximated by 1 - (1 - p)^n
        and we use a numerically stable version -expm1(num_tries * log1p(-p))

        Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
        """
        with torch.no_grad():
            self.range_max = range_max
            log_indices = torch.arange(1., range_max + 2., 1.).log_()
            self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]

            n_tries = 2 * n_sample
            # log(1 - (1 - p)^n_tries), computed in double precision.
            # The negation must be applied BEFORE log1p: the previous
            # `-self.dist.double().log1p_()` evaluated -log1p(p), not
            # log1p(-p), because of operator precedence.
            self.log_q = (-((-self.dist.double()).log1p_() * n_tries).expm1_()).log_().float()

        self.n_sample = n_sample

    def sample(self, labels):
        """
            labels: [b1, b2]
        Return
            true_log_probs: [b1, b2]
            samp_log_probs: [n_sample]
            neg_samples: [n_sample]

        The number of negatives actually returned is the number of distinct
        ids among 2 * n_sample draws, so it varies from call to call.
        """
        n_tries = 2 * self.n_sample

        with torch.no_grad():
            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
            device = labels.device
            neg_samples = neg_samples.to(device)
            # Move the lookup table to the labels' device before indexing so
            # that table and indices always live on the same device.
            log_q = self.log_q.to(device)
            true_log_probs = log_q[labels]
            samp_log_probs = log_q[neg_samples]
            return true_log_probs, samp_log_probs, neg_samples
|
||||||
|
|
||||||
|
def sample_logits(embedding, bias, labels, inputs, sampler):
    """
        embedding: an nn.Embedding layer
        bias: [n_vocab]
        labels: [b1, b2]
        inputs: [b1, b2, n_emb]
        sampler: you may use a LogUniformSampler
    Return
        logits: [b1, b2, 1 + n_sample]

    The true class's logit comes first along the last dimension, followed by
    one logit per sampled negative; negatives that coincide with the true
    label at a position are filled with -1e30.
    """
    true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
    num_neg = neg_samples.size(0)
    rows, cols = labels.size(0), labels.size(1)

    # Look up embeddings and biases for true labels and negatives in one go.
    candidate_ids = torch.cat([labels.view(-1), neg_samples])
    candidate_w = embedding(candidate_ids)
    candidate_b = bias[candidate_ids]

    pos_w = candidate_w[:-num_neg].view(rows, cols, -1)
    neg_w = candidate_w[-num_neg:].view(num_neg, -1)
    pos_b = candidate_b[:-num_neg].view(rows, cols)
    neg_b = candidate_b[-num_neg:]

    # Positions where a sampled negative equals the true label.
    collisions = (labels[:, :, None] == neg_samples).detach()

    pos_logits = torch.einsum('ijk,ijk->ij', [pos_w, inputs])
    pos_logits = pos_logits + pos_b - true_log_probs

    neg_logits = torch.einsum('lk,ijk->ijl', [neg_w, inputs])
    neg_logits = neg_logits + neg_b - samp_log_probs
    neg_logits.masked_fill_(collisions, -1e30)

    return torch.cat([pos_logits[:, :, None], neg_logits], -1)
|
||||||
|
|
||||||
|
|
||||||
|
# class LogUniformSampler(object):
|
||||||
|
# def __init__(self, range_max, unique=False):
|
||||||
|
# """
|
||||||
|
# Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
|
||||||
|
# `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
|
||||||
|
# """
|
||||||
|
# self.range_max = range_max
|
||||||
|
# log_indices = torch.arange(1., range_max+2., 1.).log_()
|
||||||
|
# self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
|
||||||
|
|
||||||
|
# self.unique = unique
|
||||||
|
|
||||||
|
# if self.unique:
|
||||||
|
# self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
|
||||||
|
|
||||||
|
# def sample(self, n_sample, labels):
|
||||||
|
# pos_sample, new_labels = labels.unique(return_inverse=True)
|
||||||
|
# n_pos_sample = pos_sample.size(0)
|
||||||
|
# n_neg_sample = n_sample - n_pos_sample
|
||||||
|
|
||||||
|
# if self.unique:
|
||||||
|
# self.exclude_mask.index_fill_(0, pos_sample, 1)
|
||||||
|
# sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
|
||||||
|
# self.exclude_mask.index_fill_(0, pos_sample, 0)
|
||||||
|
# else:
|
||||||
|
# sample_dist = self.dist
|
||||||
|
|
||||||
|
# neg_sample = torch.multinomial(sample_dist, n_neg_sample)
|
||||||
|
|
||||||
|
# sample = torch.cat([pos_sample, neg_sample])
|
||||||
|
# sample_prob = self.dist[sample]
|
||||||
|
|
||||||
|
# return new_labels, sample, sample_prob
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Small smoke test of LogUniformSampler + sample_logits on random data.
    S, B = 3, 4
    n_vocab = 10000
    n_sample = 5
    H = 32

    labels = torch.LongTensor(S, B).random_(0, n_vocab)

    # LogUniformSampler takes (range_max, n_sample); it has no `unique` flag.
    sampler = LogUniformSampler(n_vocab, n_sample)

    embedding = nn.Embedding(n_vocab, H)
    bias = torch.zeros(n_vocab)
    inputs = torch.Tensor(S, B, H).normal_()

    # sample_logits takes five arguments and returns the logits tensor only.
    logits = sample_logits(embedding, bias, labels, inputs, sampler)
    print('logits', logits.detach().numpy().tolist())
    print('logits shape', logits.size())
|
||||||
|
|
||||||
508
pytorch_pretrained_bert/tokenization_transfo_xl.py
Normal file
508
pytorch_pretrained_bert/tokenization_transfo_xl.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HugginFace Inc. team.
|
||||||
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Tokenization classes for Transformer XL model.
|
||||||
|
Directly adapted from https://github.com/kimiyoung/transformer-xl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
import logging
import os
import pickle
import re
from collections import Counter, OrderedDict

import numpy as np
import torch
from tqdm import tqdm

from .file_utils import cached_path
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Shortcut model name -> URL of its vocabulary file.
# NOTE(review): the 'transfo-xl' entry points at the OpenAI GPT vocabulary
# file; this looks like a placeholder - confirm the intended URL.
PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
}
# Shortcut model name -> URL of its merges file.
# NOTE(review): keyed by 'openai-gpt' only, so there is no entry for
# 'transfo-xl' - confirm whether this map is still needed.
PRETRAINED_MERGES_ARCHIVE_MAP = {
    'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
}
# Shortcut model name -> number of positional embeddings of the model.
# NOTE(review): also keyed by 'openai-gpt' only - confirm.
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'openai-gpt': 512,
}
# File names looked up inside a local model directory.
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
|
||||||
|
|
||||||
|
class TransfoXLTokenizer(object):
    """
    Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl

    Typical use: count symbols with ``count_file`` / ``count_sents`` (or give
    a ``vocab_file``), call ``build_vocab()``, then tokenize / encode.
    """
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a TransfoXLTokenizer.
        Download and cache the vocabulary if needed.

        The resolved vocabulary path is passed to the constructor as
        ``vocab_file``; call ``build_vocab()`` on the returned tokenizer to
        load it. Returns None (after logging an error) when the vocabulary
        file cannot be found.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
        except FileNotFoundError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find file {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    vocab_file))
            return None
        if resolved_vocab_file == vocab_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                vocab_file, resolved_vocab_file))
        # Instantiate tokenizer. The path is passed by keyword: the
        # constructor's leading positional parameters are `special` and
        # `min_freq`, not file paths.
        kwargs['vocab_file'] = resolved_vocab_file
        return cls(*inputs, **kwargs)

    def __init__(self, special=(), min_freq=0, max_size=None, lower_case=True,
                 delimiter=None, vocab_file=None):
        """
        special: symbols registered first by build_vocab(), each also exposed
            as a ``<name>_idx`` attribute.
        min_freq / max_size: frequency threshold and size cap applied by
            build_vocab() when building from counts.
        lower_case: lower-case every line in tokenize().
        delimiter: separator given to ``str.split``; '' splits into characters.
        vocab_file: if set, build_vocab() reads the symbols from this file.
        """
        self.counter = Counter()
        # Copy so that neither a shared default nor the caller's list is aliased.
        self.special = list(special)
        self.min_freq = min_freq
        self.max_size = max_size
        self.lower_case = lower_case
        self.delimiter = delimiter
        self.vocab_file = vocab_file

    def count_file(self, path, verbose=False, add_eos=False):
        """Tokenize every line of `path`, update the counter, return the token lists."""
        if verbose: print('counting file {} ...'.format(path))
        assert os.path.exists(path)

        sents = []
        with open(path, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('    line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=add_eos)
                self.counter.update(symbols)
                sents.append(symbols)

        return sents

    def count_sents(self, sents, verbose=False):
        """
            sents : a list of sentences, each a list of tokenized symbols
        """
        if verbose: print('counting {} sents ...'.format(len(sents)))
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print('    line {}'.format(idx))
            self.counter.update(symbols)

    def _build_from_file(self, vocab_file):
        """Load the vocabulary from a text file (first field of each line).

        Raises KeyError if the file does not contain the '<UNK>' symbol.
        """
        self.idx2sym = []
        self.sym2idx = OrderedDict()

        with open(vocab_file, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.strip().split()
                if not fields:
                    # Blank line: nothing to register.
                    continue
                self.add_symbol(fields[0])
        self.unk_idx = self.sym2idx['<UNK>']

    def build_vocab(self):
        """Build idx2sym / sym2idx from `vocab_file` if set, else from the counter."""
        if self.vocab_file:
            print('building vocab from {}'.format(self.vocab_file))
            self._build_from_file(self.vocab_file)
            print('final vocab size {}'.format(len(self)))
        else:
            print('building vocab with min_freq={}, max_size={}'.format(
                self.min_freq, self.max_size))
            self.idx2sym = []
            self.sym2idx = OrderedDict()

            for sym in self.special:
                self.add_special(sym)

            # most_common() is sorted by decreasing count, so the first
            # symbol below the threshold ends the loop.
            for sym, cnt in self.counter.most_common(self.max_size):
                if cnt < self.min_freq: break
                self.add_symbol(sym)

            print('final vocab size {} from {} unique tokens'.format(
                len(self), len(self.counter)))

    def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
                    add_double_eos=False):
        """Encode each line of `path` as a LongTensor; concatenate them if `ordered`."""
        if verbose: print('encoding file {} ...'.format(path))
        assert os.path.exists(path)
        encoded = []
        with open(path, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('    line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=add_eos,
                                        add_double_eos=add_double_eos)
                encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def encode_sents(self, sents, ordered=False, verbose=False):
        """Encode already tokenized sentences; concatenate them if `ordered`."""
        if verbose: print('encoding {} sents ...'.format(len(sents)))
        encoded = []
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print('    line {}'.format(idx))
            encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def add_special(self, sym):
        """Register `sym` and expose its index as attribute '<name>_idx' ('<' and '>' stripped)."""
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1
            setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

    def add_symbol(self, sym):
        """Register `sym` under the next free index if it is not known yet."""
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def get_sym(self, idx):
        """Return the symbol stored at index `idx`."""
        assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
        return self.idx2sym[idx]

    def get_idx(self, sym):
        """Return the index of `sym`, falling back to `unk_idx` for unknown symbols."""
        if sym in self.sym2idx:
            return self.sym2idx[sym]
        assert '<eos>' not in sym
        assert hasattr(self, 'unk_idx')
        return self.unk_idx

    def convert_ids_to_tokens(self, indices):
        """Converts a sequence of indices in symbols using the vocab."""
        return [self.get_sym(idx) for idx in indices]

    def convert_tokens_to_ids(self, symbols):
        """Converts a sequence of symbols into ids using the vocab."""
        return [self.get_idx(sym) for sym in symbols]

    def convert_to_tensor(self, symbols):
        """Converts a sequence of symbols into a LongTensor of ids."""
        return torch.LongTensor(self.convert_tokens_to_ids(symbols))

    def decode(self, indices, exclude=None):
        """Converts a sequence of indices in a string."""
        if exclude is None:
            return ' '.join([self.get_sym(idx) for idx in indices])
        else:
            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])

    def __len__(self):
        return len(self.idx2sym)

    def tokenize(self, line, add_eos=False, add_double_eos=False):
        """Split `line` into a list of symbols, optionally adding end markers."""
        line = line.strip()
        # convert to lower case
        if self.lower_case:
            line = line.lower()

        # empty delimiter '' will evaluate False
        if self.delimiter == '':
            # Character-level: a list (not the raw str) so that the end
            # markers below can be concatenated.
            symbols = list(line)
        else:
            symbols = line.split(self.delimiter)

        if add_double_eos:  # lm1b
            return ['<S>'] + symbols + ['<S>']
        elif add_eos:
            return symbols + ['<eos>']
        else:
            return symbols
|
||||||
|
|
||||||
|
|
||||||
|
class LMOrderedIterator(object):
    """Iterate over one long, ordered token tensor in (data, target, seq_len) batches."""

    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
        """
            data -- LongTensor -- the LongTensor is strictly ordered
        """
        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = 0 if ext_len is None else ext_len

        self.device = device

        # Number of time steps each of the bsz parallel streams receives.
        self.n_step = data.size(0) // bsz

        # Drop the remainder that does not fill a whole step, then lay the
        # data out as [n_step x bsz] on the target device.
        trimmed = data.narrow(0, 0, self.n_step * bsz)
        self.data = trimmed.view(bsz, -1).t().contiguous().to(device)

        # Number of mini-batches (ceiling division of n_step by bptt).
        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt

    def get_batch(self, i, bptt=None):
        """Return (data, target, seq_len) for the window starting at step `i`.

        `target` is `data` shifted by one step; `data` additionally reaches
        back up to `ext_len` steps before `i`.
        """
        if bptt is None:
            bptt = self.bptt
        total_steps = self.data.size(0)
        seq_len = min(bptt, total_steps - 1 - i)

        first = max(0, i - self.ext_len)
        last = i + seq_len

        inputs = self.data[first:last]
        target = self.data[i + 1:i + 1 + seq_len]

        return inputs, target, seq_len

    def get_fixlen_iter(self, start=0):
        """Yield consecutive batches of (at most) `bptt` steps, starting at `start`."""
        for pos in range(start, self.data.size(0) - 1, self.bptt):
            yield self.get_batch(pos)

    def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
        """Yield batches whose length is drawn around `bptt` (or half of it)."""
        longest = self.bptt + max_deviation * std
        pos = start
        while True:
            # Use the full bptt as mean with probability 0.95, half otherwise.
            mean_len = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
            drawn = int(np.random.normal(mean_len, std))
            length = min(longest, max(min_len, drawn))
            data, target, seq_len = self.get_batch(pos, length)
            pos += seq_len
            yield data, target, seq_len
            if pos >= self.data.size(0) - 2:
                break

    def __iter__(self):
        return self.get_fixlen_iter()
|
||||||
|
|
||||||
|
|
||||||
|
class LMShuffledIterator(object):
    """Batch iterator over a collection of independent sentences.

    Each of the ``bsz`` batch columns is filled by consuming sentences
    from a shared sentence stream, one after the other.
    """

    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
        """
            data -- list[LongTensor] -- there is no order among the LongTensors
        """
        self.data = data

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = 0 if ext_len is None else ext_len

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self):
        """Yield every sentence once, in random order when ``self.shuffle`` is set."""
        # index iterator
        if self.shuffle:
            order = np.random.permutation(len(self.data))
        else:
            order = np.array(range(len(self.data)))

        # sentence iterator
        for idx in order:
            yield self.data[idx]

    def stream_iterator(self, sent_stream):
        """Yield ``(data, target, bptt)`` batches built from ``sent_stream``.

        Iteration stops as soon as the stream runs dry while a batch is
        being filled; that partially filled batch is not yielded. The
        yielded tensors are reused and overwritten by the next step.
        """
        # one pending (partially consumed) sentence per batch column
        pending = [None] * self.bsz

        inputs = torch.LongTensor(self.bptt, self.bsz)
        labels = torch.LongTensor(self.bptt, self.bsz)

        n_keep = 0

        while True:
            # inputs : [n_keep+bptt x bsz]
            # labels : [bptt x bsz]
            inputs[n_keep:].fill_(-1)
            labels.fill_(-1)

            for col in range(self.bsz):
                done = 0
                try:
                    while done < self.bptt:
                        current = pending[col]
                        if current is None or len(current) <= 1:
                            pending[col] = next(sent_stream)
                        # number of new tokens to fill in
                        take = min(len(pending[col]) - 1, self.bptt - done)
                        # first n_keep rows are retained from the last batch
                        row = n_keep + done
                        inputs[row:row + take, col] = pending[col][:take]
                        labels[done:done + take, col] = pending[col][1:take + 1]
                        pending[col] = pending[col][take:]
                        done += take
                except StopIteration:
                    # not enough sentences left to complete this batch
                    return

            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            yield inputs, labels, self.bptt

            # carry the last ext_len rows over as context for the next batch
            n_keep = min(inputs.size(0), self.ext_len)
            if n_keep > 0:
                inputs[:n_keep] = inputs[-n_keep:]
            inputs.resize_(n_keep + self.bptt, inputs.size(1))

    def __iter__(self):
        """Iterate over the batches of one pass through the data."""
        # sent_stream is an iterator
        sentences = self.get_sent_stream()

        for batch in self.stream_iterator(sentences):
            yield batch
|
||||||
|
|
||||||
|
|
||||||
|
class LMMultiFileIterator(LMShuffledIterator):
    """Shuffled-sentence iterator that reads its sentences file by file.

    Every path in ``paths`` is encoded with ``vocab.encode_file`` when it
    is reached, and batches are produced from it with the inherited
    ``stream_iterator``.
    """

    def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
                 shuffle=False):
        self.paths = paths
        self.vocab = vocab

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = 0 if ext_len is None else ext_len

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self, path):
        """Encode the file at ``path`` and return an iterator over its sentences."""
        encoded = self.vocab.encode_file(path, add_double_eos=True)
        if self.shuffle:
            # shuffles the encoded sentences in place
            np.random.shuffle(encoded)
        return iter(encoded)

    def __iter__(self):
        """Yield the batches of every file in turn."""
        if self.shuffle:
            # note: reorders self.paths in place
            np.random.shuffle(self.paths)

        for path in self.paths:
            # one sentence iterator per file
            sentences = self.get_sent_stream(path)
            for batch in self.stream_iterator(sentences):
                yield batch
|
||||||
|
|
||||||
|
|
||||||
|
class Corpus(object):
    """A dataset's vocabulary together with its train/valid/test splits.

    ``dataset`` selects which files under ``path`` are counted for the
    vocabulary and how each split is encoded. Extra positional and
    keyword arguments are forwarded to ``Vocab``.
    """

    def __init__(self, path, dataset, *args, **kwargs):
        def in_corpus_dir(filename):
            return os.path.join(path, filename)

        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        # Step 1: gather the symbol counts the vocabulary is built from.
        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            for filename in ('train.txt', 'valid.txt', 'test.txt'):
                self.vocab.count_file(in_corpus_dir(filename))
        elif self.dataset == 'wt103':
            self.vocab.count_file(in_corpus_dir('train.txt'))
        elif self.dataset == 'lm1b':
            train_paths = glob.glob(os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*'))
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        # Step 2: encode the splits.
        encode = self.vocab.encode_file
        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = encode(in_corpus_dir('train.txt'), ordered=True)
            self.valid = encode(in_corpus_dir('valid.txt'), ordered=True)
            self.test = encode(in_corpus_dir('test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = encode(
                in_corpus_dir('train.txt'), ordered=True, add_eos=False)
            self.valid = encode(
                in_corpus_dir('valid.txt'), ordered=True, add_eos=False)
            self.test = encode(
                in_corpus_dir('test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            # training files stay on disk; only their paths are kept
            self.train = train_paths
            self.valid = encode(
                in_corpus_dir('valid.txt'), ordered=False, add_double_eos=True)
            self.test = encode(
                in_corpus_dir('test.txt'), ordered=False, add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        """Build the batch iterator for ``split`` ('train', 'valid' or 'test').

        Extra arguments are forwarded to the iterator class chosen for
        ``self.dataset``.
        """
        ordered_sets = ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']

        if split == 'train':
            if self.dataset in ordered_sets:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset == 'lm1b':
                kwargs['shuffle'] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
        elif split in ['valid', 'test']:
            if split == 'valid':
                data = self.valid
            else:
                data = self.test
            if self.dataset in ordered_sets:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == 'lm1b':
                data_iter = LMShuffledIterator(data, *args, **kwargs)

        return data_iter
|
||||||
|
|
||||||
|
|
||||||
|
def get_lm_corpus(datadir, dataset):
    """Load the corpus for ``dataset`` from ``datadir``, using a cache if present.

    Lookup order:
      1. ``<datadir>/cache.pt``  -- loaded with ``torch.load``;
      2. ``<datadir>/cache.pkl`` -- loaded with ``pickle.load``;
      3. otherwise the corpus is built with ``Corpus`` and saved to
         ``cache.pt`` with ``torch.save``.

    Args:
        datadir: directory holding the dataset files and the cache.
        dataset: dataset name; selects the ``Vocab`` keyword arguments.

    Returns:
        The loaded or newly built corpus object.
    """
    fn = os.path.join(datadir, 'cache.pt')
    fn_pickle = os.path.join(datadir, 'cache.pkl')
    # NOTE(review): both cache formats are pickle-based and execute
    # arbitrary code on load -- only point datadir at trusted locations.
    if os.path.exists(fn):
        print('Loading cached dataset...')
        # Load the same file that was just checked for (previously the
        # pickle path was loaded here, failing when only cache.pt existed).
        corpus = torch.load(fn)
    elif os.path.exists(fn_pickle):
        print('Loading cached dataset from pickle...')
        with open(fn_pickle, "rb") as fp:
            corpus = pickle.load(fp)
    else:
        print('Producing dataset {}...'.format(dataset))
        kwargs = {}
        if dataset in ['wt103', 'wt2']:
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = False
        elif dataset == 'ptb':
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = True
        elif dataset == 'lm1b':
            kwargs['special'] = []
            kwargs['lower_case'] = False
            kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
        elif dataset in ['enwik8', 'text8']:
            # these datasets use the Vocab defaults
            pass

        corpus = Corpus(datadir, dataset, **kwargs)
        torch.save(corpus, fn)

    return corpus
|
||||||
Reference in New Issue
Block a user