adding gpt2
This commit is contained in:
@@ -13,6 +13,9 @@ from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
|
|||||||
load_tf_weights_in_openai_gpt)
|
load_tf_weights_in_openai_gpt)
|
||||||
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
|
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
|
||||||
load_tf_weights_in_transfo_xl)
|
load_tf_weights_in_transfo_xl)
|
||||||
|
from .modeling_gpt2 import (GPT2Config, GPT2Model,
|
||||||
|
GPT2LMHeadModel, GPT2DoubleHeadsModel,
|
||||||
|
load_tf_weights_in_gpt2)
|
||||||
|
|
||||||
from .optimization import BertAdam
|
from .optimization import BertAdam
|
||||||
from .optimization_openai import OpenAIAdam
|
from .optimization_openai import OpenAIAdam
|
||||||
|
|||||||
@@ -4,13 +4,15 @@ def main():
|
|||||||
if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
|
if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
|
||||||
"convert_tf_checkpoint_to_pytorch",
|
"convert_tf_checkpoint_to_pytorch",
|
||||||
"convert_openai_checkpoint",
|
"convert_openai_checkpoint",
|
||||||
"convert_transfo_xl_checkpoint"
|
"convert_transfo_xl_checkpoint",
|
||||||
|
"convert_gpt2_checkpoint",
|
||||||
]:
|
]:
|
||||||
print(
|
print(
|
||||||
"Should be used as one of: \n"
|
"Should be used as one of: \n"
|
||||||
">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
|
">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
|
||||||
">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]` or \n"
|
">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
|
||||||
">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
|
">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
|
||||||
|
">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
|
||||||
else:
|
else:
|
||||||
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
|
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
|
||||||
try:
|
try:
|
||||||
@@ -40,7 +42,7 @@ def main():
|
|||||||
convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
|
convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
|
||||||
OPENAI_GPT_CONFIG,
|
OPENAI_GPT_CONFIG,
|
||||||
PYTORCH_DUMP_OUTPUT)
|
PYTORCH_DUMP_OUTPUT)
|
||||||
else:
|
elif sys.argv[1] == "convert_transfo_xl_checkpoint":
|
||||||
try:
|
try:
|
||||||
from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
|
from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -61,5 +63,21 @@ def main():
|
|||||||
else:
|
else:
|
||||||
TF_CONFIG = ""
|
TF_CONFIG = ""
|
||||||
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
|
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
|
||||||
|
except ImportError:
|
||||||
|
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||||
|
"In that case, it requires TensorFlow to be installed. Please see "
|
||||||
|
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||||
|
raise
|
||||||
|
|
||||||
|
TF_CHECKPOINT = sys.argv[2]
|
||||||
|
PYTORCH_DUMP_OUTPUT = sys.argv[3]
|
||||||
|
if len(sys.argv) == 5:
|
||||||
|
TF_CONFIG = sys.argv[4]
|
||||||
|
else:
|
||||||
|
TF_CONFIG = ""
|
||||||
|
convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
72
pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py
Executable file
72
pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py
Executable file
@@ -0,0 +1,72 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The HugginFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Convert OpenAI GPT checkpoint."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from io import open
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
|
||||||
|
GPT2Config,
|
||||||
|
GPT2Model,
|
||||||
|
load_tf_weights_in_gpt2)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
|
||||||
|
# Construct model
|
||||||
|
if gpt2_config_file == "":
|
||||||
|
config = GPT2Config()
|
||||||
|
else:
|
||||||
|
config = GPT2Config(gpt2_config_file)
|
||||||
|
model = GPT2Model(config)
|
||||||
|
|
||||||
|
# Load weights from numpy
|
||||||
|
load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)
|
||||||
|
|
||||||
|
# Save pytorch-model
|
||||||
|
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||||
|
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||||
|
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||||
|
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||||
|
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||||
|
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(config.to_json_string())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
## Required parameters
|
||||||
|
parser.add_argument("--gpt2_checkpoint_path",
|
||||||
|
default = None,
|
||||||
|
type = str,
|
||||||
|
required = True,
|
||||||
|
help = "Path the TensorFlow checkpoint path.")
|
||||||
|
parser.add_argument("--pytorch_dump_folder_path",
|
||||||
|
default = None,
|
||||||
|
type = str,
|
||||||
|
required = True,
|
||||||
|
help = "Path to the output PyTorch model.")
|
||||||
|
parser.add_argument("--gpt2_config_file",
|
||||||
|
default = "",
|
||||||
|
type = str,
|
||||||
|
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
|
||||||
|
"This specifies the model architecture.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
|
||||||
|
args.gpt2_config_file,
|
||||||
|
args.pytorch_dump_folder_path)
|
||||||
681
pytorch_pretrained_bert/modeling_gpt2.py
Normal file
681
pytorch_pretrained_bert/modeling_gpt2.py
Normal file
@@ -0,0 +1,681 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The OpenAI Team Authors and HugginFace Inc. team.
|
||||||
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch OpenAI GPT-2 model."""
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tarfile
|
||||||
|
import tempfile
|
||||||
|
import sys
|
||||||
|
from io import open
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from .file_utils import cached_path
|
||||||
|
from .modeling import BertLayerNorm as LayerNorm
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
|
||||||
|
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
|
||||||
|
|
||||||
|
CONFIG_NAME = "config.json"
|
||||||
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
|
|
||||||
|
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
|
||||||
|
""" Load tf checkpoints in a pytorch model
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import re
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
except ImportError:
|
||||||
|
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||||
|
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||||
|
raise
|
||||||
|
tf_path = os.path.abspath(gpt2_checkpoint_path)
|
||||||
|
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||||
|
# Load weights from TF model
|
||||||
|
init_vars = tf.train.list_variables(tf_path)
|
||||||
|
names = []
|
||||||
|
arrays = []
|
||||||
|
for name, shape in init_vars:
|
||||||
|
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||||
|
array = tf.train.load_variable(tf_path, name)
|
||||||
|
names.append(name)
|
||||||
|
arrays.append(array)
|
||||||
|
|
||||||
|
for name, array in zip(names, arrays):
|
||||||
|
name = name.split('/')
|
||||||
|
pointer = model
|
||||||
|
for m_name in name:
|
||||||
|
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||||
|
l = re.split(r'_(\d+)', m_name)
|
||||||
|
else:
|
||||||
|
l = [m_name]
|
||||||
|
if l[0] == 'w' or l[0] == 'g':
|
||||||
|
pointer = getattr(pointer, 'weight')
|
||||||
|
elif l[0] == 'b':
|
||||||
|
pointer = getattr(pointer, 'bias')
|
||||||
|
else:
|
||||||
|
pointer = getattr(pointer, l[0])
|
||||||
|
if len(l) >= 2:
|
||||||
|
num = int(l[1])
|
||||||
|
pointer = pointer[num]
|
||||||
|
if m_name[-11:] == '_embeddings':
|
||||||
|
pointer = getattr(pointer, 'weight')
|
||||||
|
elif m_name == 'kernel':
|
||||||
|
array = np.transpose(array)
|
||||||
|
try:
|
||||||
|
assert pointer.shape == array.shape
|
||||||
|
except AssertionError as e:
|
||||||
|
e.args += (pointer.shape, array.shape)
|
||||||
|
raise
|
||||||
|
print("Initialize PyTorch weight {}".format(name))
|
||||||
|
pointer.data = torch.from_numpy(array)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def gelu(x):
|
||||||
|
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2Config(object):
|
||||||
|
"""Configuration class to store the configuration of a `GPT2Model`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size_or_config_json_file=40478,
|
||||||
|
n_positions=1024,
|
||||||
|
n_ctx=1024,
|
||||||
|
n_embd=768,
|
||||||
|
n_layer=12,
|
||||||
|
n_head=12,
|
||||||
|
layer_norm_epsilon=1e-5,
|
||||||
|
initializer_range=0.02,
|
||||||
|
):
|
||||||
|
"""Constructs GPT2Config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
|
||||||
|
n_positions: Number of positional embeddings.
|
||||||
|
n_ctx: Size of the causal mask (usually same as n_positions).
|
||||||
|
n_embd: Dimensionality of the embeddings and hidden states.
|
||||||
|
n_layer: Number of hidden layers in the Transformer encoder.
|
||||||
|
n_head: Number of attention heads for each attention layer in
|
||||||
|
the Transformer encoder.
|
||||||
|
layer_norm_epsilon: epsilon to use in the layer norm layers
|
||||||
|
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||||
|
initializing all weight matrices.
|
||||||
|
"""
|
||||||
|
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||||
|
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||||
|
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
||||||
|
json_config = json.loads(reader.read())
|
||||||
|
for key, value in json_config.items():
|
||||||
|
self.__dict__[key] = value
|
||||||
|
elif isinstance(vocab_size_or_config_json_file, int):
|
||||||
|
self.vocab_size = vocab_size_or_config_json_file
|
||||||
|
self.n_ctx = n_ctx
|
||||||
|
self.n_positions = n_positions
|
||||||
|
self.n_embd = n_embd
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.n_head = n_head
|
||||||
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"First argument must be either a vocabulary size (int)"
|
||||||
|
"or the path to a pretrained model config file (str)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, json_object):
|
||||||
|
"""Constructs a `GPT2Config` from a Python dictionary of parameters."""
|
||||||
|
config = GPT2Config(vocab_size_or_config_json_file=-1)
|
||||||
|
for key, value in json_object.items():
|
||||||
|
config.__dict__[key] = value
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_file(cls, json_file):
|
||||||
|
"""Constructs a `GPT2Config` from a json file of parameters."""
|
||||||
|
with open(json_file, "r", encoding="utf-8") as reader:
|
||||||
|
text = reader.read()
|
||||||
|
return cls.from_dict(json.loads(text))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self.to_json_string())
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""Serializes this instance to a Python dictionary."""
|
||||||
|
output = copy.deepcopy(self.__dict__)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def to_json_string(self):
|
||||||
|
"""Serializes this instance to a JSON string."""
|
||||||
|
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1D(nn.Module):
|
||||||
|
def __init__(self, nf, nx):
|
||||||
|
super(Conv1D, self).__init__()
|
||||||
|
self.nf = nf
|
||||||
|
w = torch.empty(nx, nf)
|
||||||
|
nn.init.normal_(w, std=0.02)
|
||||||
|
self.weight = Parameter(w)
|
||||||
|
self.bias = Parameter(torch.zeros(nf))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
size_out = x.size()[:-1] + (self.nf,)
|
||||||
|
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
||||||
|
x = x.view(*size_out)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, nx, n_ctx, config, scale=False):
|
||||||
|
super(Attention, self).__init__()
|
||||||
|
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||||
|
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||||
|
assert n_state % config.n_head == 0
|
||||||
|
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||||
|
self.n_head = config.n_head
|
||||||
|
self.split_size = n_state
|
||||||
|
self.scale = scale
|
||||||
|
self.c_attn = Conv1D(n_state * 3, nx)
|
||||||
|
self.c_proj = Conv1D(n_state, nx)
|
||||||
|
|
||||||
|
def _attn(self, q, k, v):
|
||||||
|
w = torch.matmul(q, k)
|
||||||
|
if self.scale:
|
||||||
|
w = w / math.sqrt(v.size(-1))
|
||||||
|
# w = w * self.bias + -1e9 * (1 - self.bias) # TF implem method: mask_attn_weights
|
||||||
|
# XD: self.b may be larger than w, so we need to crop it
|
||||||
|
b = self.bias[:, :, : w.size(-2), : w.size(-1)]
|
||||||
|
w = w * b + -1e10 * (1 - b)
|
||||||
|
|
||||||
|
w = nn.Softmax(dim=-1)(w)
|
||||||
|
return torch.matmul(w, v)
|
||||||
|
|
||||||
|
def merge_heads(self, x):
|
||||||
|
x = x.permute(0, 2, 1, 3).contiguous()
|
||||||
|
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
|
||||||
|
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
|
||||||
|
|
||||||
|
def split_heads(self, x, k=False):
|
||||||
|
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
|
||||||
|
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
|
||||||
|
if k:
|
||||||
|
return x.permute(0, 2, 3, 1)
|
||||||
|
else:
|
||||||
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
def forward(self, x, past=None):
|
||||||
|
x = self.c_attn(x)
|
||||||
|
query, key, value = x.split(self.split_size, dim=2)
|
||||||
|
query = self.split_heads(query)
|
||||||
|
key = self.split_heads(key, k=True)
|
||||||
|
value = self.split_heads(value)
|
||||||
|
present = key, value
|
||||||
|
if past is not None:
|
||||||
|
past_key, past_value = past
|
||||||
|
key = torch.cat((past_key, key), dim=-2)
|
||||||
|
value = torch.cat((past_value, value), dim=-2)
|
||||||
|
a = self._attn(query, key, value)
|
||||||
|
a = self.merge_heads(a)
|
||||||
|
a = self.c_proj(a)
|
||||||
|
return a, present
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
|
||||||
|
super(MLP, self).__init__()
|
||||||
|
nx = config.n_embd
|
||||||
|
self.c_fc = Conv1D(n_state, nx)
|
||||||
|
self.c_proj = Conv1D(nx, n_state)
|
||||||
|
self.act = gelu
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.act(self.c_fc(x))
|
||||||
|
h2 = self.c_proj(h)
|
||||||
|
return h2
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
def __init__(self, n_ctx, config, scale=False):
|
||||||
|
super(Block, self).__init__()
|
||||||
|
nx = config.n_embd
|
||||||
|
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||||
|
self.attn = Attention(nx, n_ctx, config, scale)
|
||||||
|
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||||
|
self.mlp = MLP(4 * nx, config)
|
||||||
|
|
||||||
|
def forward(self, x, past):
|
||||||
|
a, present = self.attn(self.ln_1(x), past=past)
|
||||||
|
x = x + a
|
||||||
|
m = self.mlp(self.ln_2(c))
|
||||||
|
x = x + m
|
||||||
|
return x, present
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2LMHead(nn.Module):
|
||||||
|
""" Language Model Head for the transformer """
|
||||||
|
|
||||||
|
def __init__(self, model_embeddings_weights, config):
|
||||||
|
super(GPT2LMHead, self).__init__()
|
||||||
|
self.n_embd = config.n_embd
|
||||||
|
self.set_embeddings_weights(model_embeddings_weights)
|
||||||
|
|
||||||
|
def set_embeddings_weights(self, model_embeddings_weights):
|
||||||
|
embed_shape = model_embeddings_weights.shape
|
||||||
|
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||||
|
self.decoder.weight = model_embeddings_weights # Tied weights
|
||||||
|
|
||||||
|
def forward(self, hidden_state):
|
||||||
|
# Truncated Language modeling logits (we remove the last token)
|
||||||
|
# h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
|
||||||
|
lm_logits = self.decoder(hidden_state)
|
||||||
|
return lm_logits
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2MultipleChoiceHead(nn.Module):
|
||||||
|
""" Classifier Head for the transformer """
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super(GPT2MultipleChoiceHead, self).__init__()
|
||||||
|
self.n_embd = config.n_embd
|
||||||
|
self.linear = nn.Linear(config.n_embd, 1)
|
||||||
|
|
||||||
|
nn.init.normal_(self.linear.weight, std=0.02)
|
||||||
|
nn.init.normal_(self.linear.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, mc_token_ids):
|
||||||
|
# Classification logits
|
||||||
|
# hidden_state (bsz, num_choices, seq_length, hidden_size)
|
||||||
|
# mc_token_ids (bsz, num_choices)
|
||||||
|
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
|
||||||
|
# (bsz, num_choices, 1, hidden_size)
|
||||||
|
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
|
||||||
|
# (bsz, num_choices, hidden_size)
|
||||||
|
multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
|
||||||
|
# (bsz, num_choices)
|
||||||
|
return multiple_choice_logits
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2PreTrainedModel(nn.Module):
|
||||||
|
""" An abstract class to handle weights initialization and
|
||||||
|
a simple interface for dowloading and loading pretrained models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
|
super(GPT2PreTrainedModel, self).__init__()
|
||||||
|
if not isinstance(config, GPT2Config):
|
||||||
|
raise ValueError(
|
||||||
|
"Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
|
||||||
|
"To create a model from a pretrained model use "
|
||||||
|
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
||||||
|
self.__class__.__name__, self.__class__.__name__
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def set_tied():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def init_weights(self, module):
|
||||||
|
""" Initialize the weights.
|
||||||
|
"""
|
||||||
|
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||||
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||||
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
elif isinstance(module, LayerNorm):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||||
|
Download and cache the pre-trained model file if needed.
|
||||||
|
|
||||||
|
Params:
|
||||||
|
pretrained_model_name_or_path: either:
|
||||||
|
- a str with the name of a pre-trained model to load selected in the list of:
|
||||||
|
. `openai-gpt`
|
||||||
|
- a path or url to a pretrained model archive containing:
|
||||||
|
. `gpt2_config.json` a configuration file for the model
|
||||||
|
. `pytorch_model.bin` a PyTorch dump of a GPT2Model instance
|
||||||
|
- a path or url to a pretrained model archive containing:
|
||||||
|
. `bert_config.json` a configuration file for the model
|
||||||
|
. a TensorFlow checkpoint with trained weights
|
||||||
|
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||||
|
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||||
|
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
|
||||||
|
*inputs, **kwargs: additional input for the specific Bert class
|
||||||
|
(ex: num_labels for BertForSequenceClassification)
|
||||||
|
"""
|
||||||
|
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||||
|
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
|
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
|
else:
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
|
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
||||||
|
# redirect to the cache, if necessary
|
||||||
|
try:
|
||||||
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||||
|
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
||||||
|
except EnvironmentError:
|
||||||
|
logger.error(
|
||||||
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
|
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||||
|
"at this path or url.".format(
|
||||||
|
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
|
||||||
|
archive_file, config_file
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
if resolved_archive_file == archive_file and resolved_config_file == config_file:
|
||||||
|
logger.info("loading weights file {}".format(archive_file))
|
||||||
|
logger.info("loading configuration file {}".format(config_file))
|
||||||
|
else:
|
||||||
|
logger.info("loading weights file {} from cache at {}".format(
|
||||||
|
archive_file, resolved_archive_file))
|
||||||
|
logger.info("loading configuration file {} from cache at {}".format(
|
||||||
|
config_file, resolved_config_file))
|
||||||
|
# Load config
|
||||||
|
config = GPT2Config.from_json_file(resolved_config_file)
|
||||||
|
logger.info("Model config {}".format(config))
|
||||||
|
# Instantiate model.
|
||||||
|
model = cls(config, *inputs, **kwargs)
|
||||||
|
if state_dict is None and not from_tf:
|
||||||
|
state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
|
||||||
|
if from_tf:
|
||||||
|
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
|
||||||
|
return load_tf_weights_in_gpt2(model, resolved_archive_file)
|
||||||
|
|
||||||
|
old_keys = []
|
||||||
|
new_keys = []
|
||||||
|
for key in state_dict.keys():
|
||||||
|
new_key = None
|
||||||
|
if key.endswith(".g"):
|
||||||
|
new_key = key[:-2] + ".weight"
|
||||||
|
elif key.endswith(".b"):
|
||||||
|
new_key = key[:-2] + ".bias"
|
||||||
|
elif key.endswith(".w"):
|
||||||
|
new_key = key[:-2] + ".weight"
|
||||||
|
if new_key:
|
||||||
|
old_keys.append(key)
|
||||||
|
new_keys.append(new_key)
|
||||||
|
for old_key, new_key in zip(old_keys, new_keys):
|
||||||
|
state_dict[new_key] = state_dict.pop(old_key)
|
||||||
|
|
||||||
|
missing_keys = []
|
||||||
|
unexpected_keys = []
|
||||||
|
error_msgs = []
|
||||||
|
# copy state_dict so _load_from_state_dict can modify it
|
||||||
|
metadata = getattr(state_dict, "_metadata", None)
|
||||||
|
state_dict = state_dict.copy()
|
||||||
|
if metadata is not None:
|
||||||
|
state_dict._metadata = metadata
|
||||||
|
|
||||||
|
def load(module, prefix=""):
|
||||||
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||||
|
module._load_from_state_dict(
|
||||||
|
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
||||||
|
)
|
||||||
|
for name, child in module._modules.items():
|
||||||
|
if child is not None:
|
||||||
|
load(child, prefix + name + ".")
|
||||||
|
|
||||||
|
start_model = model
|
||||||
|
if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
|
||||||
|
start_model = model.transformer
|
||||||
|
load(start_model, prefix="")
|
||||||
|
|
||||||
|
if len(missing_keys) > 0:
|
||||||
|
logger.info(
|
||||||
|
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
|
||||||
|
)
|
||||||
|
if len(unexpected_keys) > 0:
|
||||||
|
logger.info(
|
||||||
|
"Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
|
||||||
|
)
|
||||||
|
if len(error_msgs) > 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure we are still sharing the output and input embeddings after loading weights
|
||||||
|
model.set_tied()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2Model(GPT2PreTrainedModel):
|
||||||
|
"""OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").
|
||||||
|
|
||||||
|
Params:
|
||||||
|
config: a GPT2Config class instance with the configuration to build a new model
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
|
||||||
|
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
|
||||||
|
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||||
|
with the position indices (selected in the range [0, config.n_positions - 1[.
|
||||||
|
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||||
|
You can use it to add a third type of embedding to each input token in the sequence
|
||||||
|
(the previous two being the word and position embeddings).
|
||||||
|
The input, position and token_type embeddings are summed inside the Transformer before the first
|
||||||
|
self-attention block.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
`hidden_states`: the encoded-hidden-states at the top of the model
|
||||||
|
as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
|
||||||
|
(or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids)
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```python
|
||||||
|
# Already been converted into BPE token ids
|
||||||
|
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||||
|
|
||||||
|
config = modeling_gpt2.GPT2Config()
|
||||||
|
|
||||||
|
model = modeling_gpt2.GPT2Model(config)
|
||||||
|
hidden_states = model(input_ids)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super(GPT2Model, self).__init__(config)
|
||||||
|
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
||||||
|
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
|
||||||
|
block = Block(config.n_ctx, config, scale=True)
|
||||||
|
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
|
||||||
|
self.ln_f = LayerNorm(config.n_embd)
|
||||||
|
|
||||||
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
|
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
|
||||||
|
past_length = 0 if past is None else past[0][0].size(-2)
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
|
||||||
|
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||||
|
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||||
|
position_ids = position_ids.view(-1, position_ids.size(-1))
|
||||||
|
|
||||||
|
inputs_embeds = self.wte(input_ids)
|
||||||
|
position_embeds = self.wpe(position_ids)
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
|
||||||
|
token_type_embeds = self.wte(token_type_ids)
|
||||||
|
else:
|
||||||
|
token_type_embeds = 0
|
||||||
|
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||||
|
presents = []
|
||||||
|
for block in self.h:
|
||||||
|
hidden_states, present = block(hidden_states)
|
||||||
|
presents.append(present)
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
return hidden_states.view(*output_shape), presents
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||||
|
"""OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners").
|
||||||
|
|
||||||
|
Params:
|
||||||
|
config: a GPT2Config class instance with the configuration to build a new model
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
|
||||||
|
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
|
||||||
|
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||||
|
with the position indices (selected in the range [0, config.n_positions - 1[.
|
||||||
|
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||||
|
You can use it to add a third type of embedding to each input token in the sequence
|
||||||
|
(the previous two being the word and position embeddings).
|
||||||
|
The input, position and token_type embeddings are summed inside the Transformer before the first
|
||||||
|
self-attention block.
|
||||||
|
`lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
|
||||||
|
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
|
||||||
|
is only computed for the labels set in [0, ..., vocab_size]
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
if `lm_labels` is not `None`:
|
||||||
|
Outputs the language modeling loss.
|
||||||
|
else:
|
||||||
|
`lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, config.vocab_size]
|
||||||
|
(or more generally [d_1, ..., d_n, config.vocab_size] were d_1 ... d_n are the dimension of input_ids)
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```python
|
||||||
|
# Already been converted into BPE token ids
|
||||||
|
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||||
|
|
||||||
|
config = modeling_gpt2.GPT2Config()
|
||||||
|
|
||||||
|
model = modeling_gpt2.GPT2LMHeadModel(config)
|
||||||
|
lm_logits = model(input_ids)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super(GPT2LMHeadModel, self).__init__(config)
|
||||||
|
self.transformer = GPT2Model(config)
|
||||||
|
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
||||||
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
|
def set_tied(self):
|
||||||
|
""" Make sure we are sharing the embeddings
|
||||||
|
"""
|
||||||
|
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
|
||||||
|
|
||||||
|
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
|
||||||
|
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
|
||||||
|
lm_logits = self.lm_head(hidden_states)
|
||||||
|
if lm_labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
|
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
|
||||||
|
return loss
|
||||||
|
return lm_logits, presents
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||||
|
"""OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are Unsupervised Multitask Learners").
|
||||||
|
|
||||||
|
Params:
|
||||||
|
config: a GPT2Config class instance with the configuration to build a new model
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
|
||||||
|
indices selected in the range [0, config.vocab_size[
|
||||||
|
`mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from
|
||||||
|
which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
|
||||||
|
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||||
|
with the position indices (selected in the range [0, config.n_positions - 1[.
|
||||||
|
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||||
|
You can use it to add a third type of embedding to each input token in the sequence
|
||||||
|
(the previous two being the word and position embeddings).
|
||||||
|
The input, position and token_type embeddings are summed inside the Transformer before the first
|
||||||
|
self-attention block.
|
||||||
|
`lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
||||||
|
with indices selected in [-1, 0, ..., config.vocab_size]. All labels set to -1 are ignored (masked), the loss
|
||||||
|
is only computed for the labels set in [0, ..., config.vocab_size]
|
||||||
|
`multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
|
||||||
|
with indices selected in [0, ..., num_choices].
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
if `lm_labels` and `multiple_choice_labels` are not `None`:
|
||||||
|
Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
|
||||||
|
else: a tuple with
|
||||||
|
`lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, config.vocab_size]
|
||||||
|
`multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```python
|
||||||
|
# Already been converted into BPE token ids
|
||||||
|
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]]) # (bsz, number of choice, seq length)
|
||||||
|
mc_token_ids = torch.LongTensor([[2], [1]]) # (bsz, number of choice)
|
||||||
|
|
||||||
|
config = modeling_gpt2.GPT2Config()
|
||||||
|
|
||||||
|
model = modeling_gpt2.GPT2LMHeadModel(config)
|
||||||
|
lm_logits, multiple_choice_logits = model(input_ids, mc_token_ids)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super(GPT2DoubleHeadsModel, self).__init__(config)
|
||||||
|
self.transformer = GPT2Model(config)
|
||||||
|
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
||||||
|
self.multiple_choice_head = GPT2MultipleChoiceHead(config)
|
||||||
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
|
def set_tied(self):
|
||||||
|
""" Make sure we are sharing the embeddings
|
||||||
|
"""
|
||||||
|
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
|
||||||
|
|
||||||
|
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
|
||||||
|
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
|
||||||
|
lm_logits = self.lm_head(hidden_states)
|
||||||
|
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
||||||
|
losses = []
|
||||||
|
if lm_labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
|
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
|
||||||
|
if mc_labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
||||||
|
if losses:
|
||||||
|
return losses
|
||||||
|
return lm_logits, mc_logits, presents
|
||||||
199
pytorch_pretrained_bert/tokenization_gpt2.py
Normal file
199
pytorch_pretrained_bert/tokenization_gpt2.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Open AI Team Authors and The HugginFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tokenization classes for OpenAI GPT."""
|
||||||
|
from __future__ import (absolute_import, division, print_function,
|
||||||
|
unicode_literals)
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import regex as re
|
||||||
|
import sys
|
||||||
|
from io import open
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .file_utils import cached_path
|
||||||
|
from .tokenization import BasicTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||||
|
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
|
||||||
|
}
|
||||||
|
PRETRAINED_MERGES_ARCHIVE_MAP = {
|
||||||
|
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
||||||
|
}
|
||||||
|
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||||
|
'gpt2': 1024,
|
||||||
|
}
|
||||||
|
VOCAB_NAME = 'vocab.json'
|
||||||
|
MERGES_NAME = 'merges.txt'
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def bytes_to_unicode():
|
||||||
|
"""
|
||||||
|
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||||
|
The reversible bpe codes work on unicode strings.
|
||||||
|
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||||
|
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||||
|
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||||
|
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||||
|
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||||
|
"""
|
||||||
|
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||||
|
cs = bs[:]
|
||||||
|
n = 0
|
||||||
|
for b in range(2**8):
|
||||||
|
if b not in bs:
|
||||||
|
bs.append(b)
|
||||||
|
cs.append(2**8+n)
|
||||||
|
n += 1
|
||||||
|
cs = [chr(n) for n in cs]
|
||||||
|
return dict(zip(bs, cs))
|
||||||
|
|
||||||
|
def get_pairs(word):
|
||||||
|
"""Return set of symbol pairs in a word.
|
||||||
|
|
||||||
|
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||||
|
"""
|
||||||
|
pairs = set()
|
||||||
|
prev_char = word[0]
|
||||||
|
for char in word[1:]:
|
||||||
|
pairs.add((prev_char, char))
|
||||||
|
prev_char = char
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
class GPT2Tokenizer(object):
|
||||||
|
"""
|
||||||
|
GPT-2 BPE tokenizer. Peculiarities:
|
||||||
|
- Byte-level BPE
|
||||||
|
- argument special_tokens and function set_special_tokens:
|
||||||
|
can be used to add additional symbols (ex: "__classify__") to a vocabulary.
|
||||||
|
"""
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||||
|
"""
|
||||||
|
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||||
|
Download and cache the pre-trained model file if needed.
|
||||||
|
"""
|
||||||
|
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||||
|
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
|
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
|
else:
|
||||||
|
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||||
|
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
|
||||||
|
# redirect to the cache, if necessary
|
||||||
|
try:
|
||||||
|
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||||
|
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
|
||||||
|
except EnvironmentError:
|
||||||
|
logger.error(
|
||||||
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
|
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||||
|
"at this path or url.".format(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
vocab_file, merges_file))
|
||||||
|
return None
|
||||||
|
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
|
||||||
|
logger.info("loading vocabulary file {}".format(vocab_file))
|
||||||
|
logger.info("loading merges file {}".format(merges_file))
|
||||||
|
else:
|
||||||
|
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||||
|
vocab_file, resolved_vocab_file))
|
||||||
|
logger.info("loading merges file {} from cache at {}".format(
|
||||||
|
merges_file, resolved_merges_file))
|
||||||
|
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||||
|
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
|
||||||
|
# than the number of positional embeddings
|
||||||
|
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
|
||||||
|
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||||
|
# Instantiate tokenizer.
|
||||||
|
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
|
||||||
|
self.max_len = max_len if max_len is not None else int(1e12)
|
||||||
|
self.encoder = json.load(open(vocab_file))
|
||||||
|
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||||
|
self.errors = errors # how to handle errors in decoding
|
||||||
|
self.byte_encoder = bytes_to_unicode()
|
||||||
|
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
|
||||||
|
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
||||||
|
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
|
||||||
|
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||||
|
self.cache = {}
|
||||||
|
|
||||||
|
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||||
|
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.encoder)
|
||||||
|
|
||||||
|
def bpe(self, token):
|
||||||
|
if token in self.cache:
|
||||||
|
return self.cache[token]
|
||||||
|
word = tuple(token)
|
||||||
|
pairs = get_pairs(word)
|
||||||
|
|
||||||
|
if not pairs:
|
||||||
|
return token
|
||||||
|
|
||||||
|
while True:
|
||||||
|
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||||
|
if bigram not in self.bpe_ranks:
|
||||||
|
break
|
||||||
|
first, second = bigram
|
||||||
|
new_word = []
|
||||||
|
i = 0
|
||||||
|
while i < len(word):
|
||||||
|
try:
|
||||||
|
j = word.index(first, i)
|
||||||
|
new_word.extend(word[i:j])
|
||||||
|
i = j
|
||||||
|
except:
|
||||||
|
new_word.extend(word[i:])
|
||||||
|
break
|
||||||
|
|
||||||
|
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||||
|
new_word.append(first+second)
|
||||||
|
i += 2
|
||||||
|
else:
|
||||||
|
new_word.append(word[i])
|
||||||
|
i += 1
|
||||||
|
new_word = tuple(new_word)
|
||||||
|
word = new_word
|
||||||
|
if len(word) == 1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pairs = get_pairs(word)
|
||||||
|
word = ' '.join(word)
|
||||||
|
self.cache[token] = word
|
||||||
|
return word
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
bpe_tokens = []
|
||||||
|
for token in re.findall(self.pat, text):
|
||||||
|
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||||
|
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||||
|
return bpe_tokens
|
||||||
|
|
||||||
|
def decode(self, tokens):
|
||||||
|
text = ''.join([self.decoder[token] for token in tokens])
|
||||||
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||||
|
return text
|
||||||
213
tests/modeling_gpt2_test.py
Normal file
213
tests/modeling_gpt2_test.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
|
||||||
|
GPT2LMHeadModel, GPT2DoubleHeadsModel)
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2ModelTest(unittest.TestCase):
|
||||||
|
class GPT2ModelTester(object):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_position_ids=True,
|
||||||
|
use_token_type_ids=True,
|
||||||
|
use_labels=True,
|
||||||
|
vocab_size=99,
|
||||||
|
n_special=1,
|
||||||
|
n_positions=33,
|
||||||
|
n_embd=32,
|
||||||
|
n_layer=5,
|
||||||
|
n_head=4,
|
||||||
|
n_choices=3,
|
||||||
|
type_sequence_label_size=2,
|
||||||
|
initializer_range=0.02,
|
||||||
|
num_labels=3,
|
||||||
|
scope=None):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_position_ids = use_position_ids
|
||||||
|
self.use_token_type_ids = use_token_type_ids
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.n_special = n_special
|
||||||
|
self.n_positions = n_positions
|
||||||
|
self.n_embd = n_embd
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.n_head = n_head
|
||||||
|
self.n_choices = n_choices
|
||||||
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.scope = scope
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
position_ids = None
|
||||||
|
if self.use_position_ids:
|
||||||
|
position_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)
|
||||||
|
|
||||||
|
token_type_ids = None
|
||||||
|
if self.use_token_type_ids:
|
||||||
|
total_voc = self.vocab_size + self.n_special
|
||||||
|
token_type_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
|
||||||
|
|
||||||
|
mc_labels = None
|
||||||
|
lm_labels = None
|
||||||
|
mc_token_ids = None
|
||||||
|
if self.use_labels:
|
||||||
|
mc_labels = GPT2ModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||||
|
lm_labels = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
|
||||||
|
mc_token_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length)
|
||||||
|
|
||||||
|
config = GPT2Config(
|
||||||
|
vocab_size_or_config_json_file=self.vocab_size,
|
||||||
|
n_positions=self.n_positions,
|
||||||
|
n_special=self.n_special,
|
||||||
|
n_embd=self.n_embd,
|
||||||
|
n_layer=self.n_layer,
|
||||||
|
n_head=self.n_head,
|
||||||
|
initializer_range=self.initializer_range)
|
||||||
|
|
||||||
|
return (config, input_ids, token_type_ids, position_ids,
|
||||||
|
mc_labels, lm_labels, mc_token_ids)
|
||||||
|
|
||||||
|
def create_gpt2_model(self, config, input_ids, token_type_ids, position_ids,
|
||||||
|
mc_labels, lm_labels, mc_token_ids):
|
||||||
|
model = GPT2Model(config)
|
||||||
|
model.eval()
|
||||||
|
hidden_states, presents = model(input_ids, position_ids, token_type_ids)
|
||||||
|
outputs = {
|
||||||
|
"hidden_states": hidden_states,
|
||||||
|
"presents": presents,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def check_gpt2_model_output(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["hidden_states"].size()),
|
||||||
|
[self.batch_size, self.n_choices, self.seq_length, self.n_embd])
|
||||||
|
|
||||||
|
|
||||||
|
def create_gpt2_lm_head(self, config, input_ids, token_type_ids, position_ids,
|
||||||
|
mc_labels, lm_labels, mc_token_ids):
|
||||||
|
model = GPT2LMHeadModel(config)
|
||||||
|
model.eval()
|
||||||
|
loss = model(input_ids, position_ids, token_type_ids, lm_labels)
|
||||||
|
lm_logits, presents = model(input_ids, position_ids, token_type_ids)
|
||||||
|
outputs = {
|
||||||
|
"loss": loss,
|
||||||
|
"lm_logits": lm_logits,
|
||||||
|
"presents": presents,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def check_gpt2_lm_head_output(self, result):
|
||||||
|
total_voc = self.n_special + self.vocab_size
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["lm_logits"].size()),
|
||||||
|
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||||
|
|
||||||
|
def check_gpt2_lm_head_loss_output(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["loss"].size()),
|
||||||
|
[])
|
||||||
|
|
||||||
|
def create_gpt2_double_heads(self, config, input_ids, token_type_ids, position_ids,
|
||||||
|
mc_labels, lm_labels, mc_token_ids):
|
||||||
|
model = GPT2DoubleHeadsModel(config)
|
||||||
|
model.eval()
|
||||||
|
loss = model(input_ids, mc_token_ids,
|
||||||
|
lm_labels=lm_labels, mc_labels=mc_labels,
|
||||||
|
token_type_ids=token_type_ids, position_ids=position_ids)
|
||||||
|
lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||||
|
outputs = {
|
||||||
|
"loss": loss,
|
||||||
|
"lm_logits": lm_logits,
|
||||||
|
"mc_logits": mc_logits,
|
||||||
|
"presents": presents,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def check_gpt2_double_heads_output(self, result):
|
||||||
|
total_voc = self.n_special + self.vocab_size
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["lm_logits"].size()),
|
||||||
|
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["mc_logits"].size()),
|
||||||
|
[self.batch_size, self.n_choices])
|
||||||
|
|
||||||
|
def check_gpt2_double_heads_loss_output(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
[list(l.size()) for l in result["loss"]],
|
||||||
|
[[], []])
|
||||||
|
|
||||||
|
def test_default(self):
|
||||||
|
self.run_tester(GPT2ModelTest.GPT2ModelTester(self))
|
||||||
|
|
||||||
|
def test_config_to_json_string(self):
|
||||||
|
config = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37)
|
||||||
|
obj = json.loads(config.to_json_string())
|
||||||
|
self.assertEqual(obj["vocab_size"], 99)
|
||||||
|
self.assertEqual(obj["n_embd"], 37)
|
||||||
|
|
||||||
|
def run_tester(self, tester):
|
||||||
|
config_and_inputs = tester.prepare_config_and_inputs()
|
||||||
|
output_result = tester.create_gpt2_model(*config_and_inputs)
|
||||||
|
tester.check_gpt2_model_output(output_result)
|
||||||
|
|
||||||
|
output_result = tester.create_gpt2_lm_head(*config_and_inputs)
|
||||||
|
tester.check_gpt2_lm_head_output(output_result)
|
||||||
|
tester.check_gpt2_lm_head_loss_output(output_result)
|
||||||
|
|
||||||
|
output_result = tester.create_gpt2_double_heads(*config_and_inputs)
|
||||||
|
tester.check_gpt2_double_heads_output(output_result)
|
||||||
|
tester.check_gpt2_double_heads_loss_output(output_result)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
||||||
|
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||||
|
if rng is None:
|
||||||
|
rng = random.Random()
|
||||||
|
|
||||||
|
total_dims = 1
|
||||||
|
for dim in shape:
|
||||||
|
total_dims *= dim
|
||||||
|
|
||||||
|
values = []
|
||||||
|
for _ in range(total_dims):
|
||||||
|
values.append(rng.randint(0, vocab_size - 1))
|
||||||
|
|
||||||
|
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
56
tests/tokenization_gpt2_test.py
Normal file
56
tests/tokenization_gpt2_test.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
import json
|
||||||
|
|
||||||
|
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2TokenizationTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_full_tokenizer(self):
|
||||||
|
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
||||||
|
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||||
|
"w</w>", "r</w>", "t</w>",
|
||||||
|
"lo", "low", "er</w>",
|
||||||
|
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
|
||||||
|
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||||
|
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
|
||||||
|
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
|
||||||
|
json.dump(vocab_tokens, fp)
|
||||||
|
vocab_file = fp.name
|
||||||
|
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
|
||||||
|
fp.write("\n".join(merges))
|
||||||
|
merges_file = fp.name
|
||||||
|
|
||||||
|
tokenizer = GPT2Tokenizer(vocab_file, merges_file)
|
||||||
|
os.remove(vocab_file)
|
||||||
|
os.remove(merges_file)
|
||||||
|
|
||||||
|
text = "lower"
|
||||||
|
bpe_tokens = ["low", "er</w>"]
|
||||||
|
tokens = tokenizer.tokenize(text)
|
||||||
|
self.assertListEqual(tokens, bpe_tokens)
|
||||||
|
|
||||||
|
input_tokens = tokens
|
||||||
|
input_bpe_tokens = [14, 15, 20]
|
||||||
|
self.assertListEqual(
|
||||||
|
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user