[BIG] name change
This commit is contained in:
0
pytorch_transformers/tests/__init__.py
Normal file
0
pytorch_transformers/tests/__init__.py
Normal file
19
pytorch_transformers/tests/conftest.py
Normal file
19
pytorch_transformers/tests/conftest.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# content of conftest.py
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
"--runslow", action="store_true", default=False, help="run slow tests"
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
if config.getoption("--runslow"):
|
||||
# --runslow given in cli: do not skip slow tests
|
||||
return
|
||||
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
|
||||
for item in items:
|
||||
if "slow" in item.keywords:
|
||||
item.add_marker(skip_slow)
|
||||
1
pytorch_transformers/tests/fixtures/input.txt
vendored
Normal file
1
pytorch_transformers/tests/fixtures/input.txt
vendored
Normal file
@@ -0,0 +1 @@
|
||||
Who was Jim Henson ? ||| Jim Henson was a puppeteer
|
||||
33
pytorch_transformers/tests/fixtures/sample_text.txt
vendored
Normal file
33
pytorch_transformers/tests/fixtures/sample_text.txt
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
|
||||
Text should be one-sentence-per-line, with empty lines between documents.
|
||||
This sample text is public domain and was randomly selected from Project Guttenberg.
|
||||
|
||||
The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
|
||||
Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity.
|
||||
Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them.
|
||||
"Cass" Beard had risen early that morning, but not with a view to discovery.
|
||||
A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets.
|
||||
The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency.
|
||||
This was nearly opposite.
|
||||
Mr. Cassius crossed the highway, and stopped suddenly.
|
||||
Something glittered in the nearest red pool before him.
|
||||
Gold, surely!
|
||||
But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring.
|
||||
Looking at it more attentively, he saw that it bore the inscription, "May to Cass."
|
||||
Like most of his fellow gold-seekers, Cass was superstitious.
|
||||
|
||||
The fountain of classic wisdom, Hypatia herself.
|
||||
As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge.
|
||||
From my youth I felt in me a soul above the matter-entangled herd.
|
||||
She revealed to me the glorious fact, that I am a spark of Divinity itself.
|
||||
A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's.
|
||||
There is a philosophic pleasure in opening one's treasures to the modest young.
|
||||
Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street.
|
||||
Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide;
|
||||
but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind.
|
||||
Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now.
|
||||
His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert;
|
||||
while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts.
|
||||
At last they reached the quay at the opposite end of the street;
|
||||
and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers.
|
||||
He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him.
|
||||
BIN
pytorch_transformers/tests/fixtures/test_sentencepiece.model
vendored
Normal file
BIN
pytorch_transformers/tests/fixtures/test_sentencepiece.model
vendored
Normal file
Binary file not shown.
446
pytorch_transformers/tests/model_tests_commons.py
Normal file
446
pytorch_transformers/tests/model_tests_commons.py
Normal file
@@ -0,0 +1,446 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import os
|
||||
import shutil
|
||||
import json
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
for key in configs_no_init.__dict__.keys():
|
||||
if '_range' in key or '_std' in key:
|
||||
setattr(configs_no_init, key, 0.0)
|
||||
return configs_no_init
|
||||
|
||||
def _create_and_check_torchscript_output_attentions(tester, model_classes, config, inputs_dict):
|
||||
config.output_attentions = True
|
||||
_create_and_check_torchscript(tester, model_classes, config, inputs_dict)
|
||||
|
||||
def _create_and_check_torchscript_output_hidden_state(tester, model_classes, config, inputs_dict):
|
||||
config.output_hidden_states = True
|
||||
_create_and_check_torchscript(tester, model_classes, config, inputs_dict)
|
||||
|
||||
def _create_and_check_torchscript(tester, model_classes, config, inputs_dict):
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.torchscript = True
|
||||
for model_class in model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.eval()
|
||||
inputs = inputs_dict['input_ids'] # Let's keep only input_ids
|
||||
|
||||
try:
|
||||
torch.jit.trace(model, inputs)
|
||||
except RuntimeError:
|
||||
tester.parent.fail("Couldn't trace module.")
|
||||
|
||||
try:
|
||||
traced_gpt2 = torch.jit.trace(model, inputs)
|
||||
torch.jit.save(traced_gpt2, "traced_model.pt")
|
||||
except RuntimeError:
|
||||
tester.parent.fail("Couldn't save module.")
|
||||
|
||||
try:
|
||||
loaded_model = torch.jit.load("traced_model.pt")
|
||||
os.remove("traced_model.pt")
|
||||
except ValueError:
|
||||
tester.parent.fail("Couldn't load module.")
|
||||
|
||||
model.eval()
|
||||
loaded_model.eval()
|
||||
|
||||
model_params = model.parameters()
|
||||
loaded_model_params = loaded_model.parameters()
|
||||
|
||||
models_equal = True
|
||||
for p1, p2 in zip(model_params, loaded_model_params):
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
models_equal = False
|
||||
|
||||
tester.parent.assertTrue(models_equal)
|
||||
|
||||
def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0],
|
||||
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
||||
|
||||
def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
for model_class in model_classes:
|
||||
config.output_attentions = True
|
||||
config.output_hidden_states = True
|
||||
model = model_class(config=configs_no_init)
|
||||
model.eval()
|
||||
|
||||
# Prepare head_mask
|
||||
# Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
|
||||
head_mask = torch.ones(tester.num_hidden_layers, tester.num_attention_heads)
|
||||
head_mask[0, 0] = 0
|
||||
head_mask[-1, :-1] = 0
|
||||
head_mask.requires_grad_(requires_grad=True)
|
||||
inputs = inputs_dict.copy()
|
||||
inputs['head_mask'] = head_mask
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Test that we can get a gradient back for importance score computation
|
||||
output = sum(t.sum() for t in outputs[0])
|
||||
output = output.sum()
|
||||
output.backward()
|
||||
multihead_outputs = head_mask.grad
|
||||
|
||||
attentions = outputs[-1]
|
||||
hidden_states = outputs[-2]
|
||||
|
||||
# Remove Nan
|
||||
|
||||
tester.parent.assertIsNotNone(multihead_outputs)
|
||||
tester.parent.assertEqual(len(multihead_outputs), tester.num_hidden_layers)
|
||||
tester.parent.assertAlmostEqual(
|
||||
attentions[0][..., 0, :, :].flatten().sum().item(), 0.0)
|
||||
tester.parent.assertNotEqual(
|
||||
attentions[0][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||
tester.parent.assertNotEqual(
|
||||
attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
|
||||
tester.parent.assertAlmostEqual(
|
||||
attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
|
||||
tester.parent.assertNotEqual(
|
||||
attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||
|
||||
|
||||
def _create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
|
||||
for model_class in model_classes:
|
||||
config.output_attentions = True
|
||||
config.output_hidden_states = False
|
||||
model = model_class(config=config)
|
||||
model.eval()
|
||||
heads_to_prune = {0: list(range(1, tester.num_attention_heads)),
|
||||
-1: [0]}
|
||||
model.prune_heads(heads_to_prune)
|
||||
outputs = model(**inputs_dict)
|
||||
|
||||
attentions = outputs[-1]
|
||||
|
||||
tester.parent.assertEqual(
|
||||
attentions[0].shape[-3], 1)
|
||||
tester.parent.assertEqual(
|
||||
attentions[1].shape[-3], tester.num_attention_heads)
|
||||
tester.parent.assertEqual(
|
||||
attentions[-1].shape[-3], tester.num_attention_heads - 1)
|
||||
|
||||
|
||||
def _create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
|
||||
for model_class in model_classes:
|
||||
config.output_attentions = True
|
||||
config.output_hidden_states = False
|
||||
model = model_class(config)
|
||||
model.eval()
|
||||
outputs = model(**inputs_dict)
|
||||
attentions = outputs[-1]
|
||||
tester.parent.assertEqual(model.config.output_attentions, True)
|
||||
tester.parent.assertEqual(model.config.output_hidden_states, False)
|
||||
tester.parent.assertEqual(len(attentions), tester.num_hidden_layers)
|
||||
tester.parent.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[tester.num_attention_heads,
|
||||
tester.seq_length,
|
||||
tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
config.output_attentions = True
|
||||
config.output_hidden_states = True
|
||||
model = model_class(config)
|
||||
model.eval()
|
||||
outputs = model(**inputs_dict)
|
||||
tester.parent.assertEqual(out_len+1, len(outputs))
|
||||
tester.parent.assertEqual(model.config.output_attentions, True)
|
||||
tester.parent.assertEqual(model.config.output_hidden_states, True)
|
||||
|
||||
attentions = outputs[-1]
|
||||
tester.parent.assertEqual(len(attentions), tester.num_hidden_layers)
|
||||
tester.parent.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[tester.num_attention_heads,
|
||||
tester.seq_length,
|
||||
tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])
|
||||
|
||||
def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
|
||||
for model_class in model_classes:
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = False
|
||||
model = model_class(config)
|
||||
model.eval()
|
||||
outputs = model(**inputs_dict)
|
||||
hidden_states = outputs[-1]
|
||||
tester.parent.assertEqual(model.config.output_attentions, False)
|
||||
tester.parent.assertEqual(model.config.output_hidden_states, True)
|
||||
tester.parent.assertEqual(len(hidden_states), tester.num_hidden_layers + 1)
|
||||
tester.parent.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[tester.seq_length, tester.hidden_size])
|
||||
|
||||
|
||||
def create_and_check_commons(tester, config, inputs_dict, test_pruning=True, test_torchscript=True):
|
||||
_create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
|
||||
|
||||
if test_torchscript:
|
||||
_create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict)
|
||||
_create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict)
|
||||
|
||||
if test_pruning:
|
||||
_create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
|
||||
|
||||
|
||||
def ids_tensor(shape, vocab_size, rng=None, name=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
if rng is None:
|
||||
rng = random.Random()
|
||||
|
||||
total_dims = 1
|
||||
for dim in shape:
|
||||
total_dims *= dim
|
||||
|
||||
values = []
|
||||
for _ in range(total_dims):
|
||||
values.append(rng.randint(0, vocab_size - 1))
|
||||
|
||||
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
|
||||
|
||||
|
||||
class ConfigTester(object):
|
||||
def __init__(self, parent, config_class=None, **kwargs):
|
||||
self.parent = parent
|
||||
self.config_class = config_class
|
||||
self.inputs_dict = kwargs
|
||||
|
||||
def create_and_test_config_common_properties(self):
|
||||
config = self.config_class(**self.inputs_dict)
|
||||
self.parent.assertTrue(hasattr(config, 'hidden_size'))
|
||||
self.parent.assertTrue(hasattr(config, 'num_attention_heads'))
|
||||
self.parent.assertTrue(hasattr(config, 'num_hidden_layers'))
|
||||
|
||||
def create_and_test_config_to_json_string(self):
|
||||
config = self.config_class(**self.inputs_dict)
|
||||
obj = json.loads(config.to_json_string())
|
||||
for key, value in self.inputs_dict.items():
|
||||
self.parent.assertEqual(obj[key], value)
|
||||
|
||||
def create_and_test_config_to_json_file(self):
|
||||
config_first = self.config_class(**self.inputs_dict)
|
||||
json_file_path = "/tmp/config.json"
|
||||
config_first.to_json_file(json_file_path)
|
||||
config_second = self.config_class.from_json_file(json_file_path)
|
||||
os.remove(json_file_path)
|
||||
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||
|
||||
def run_common_tests(self):
|
||||
self.create_and_test_config_common_properties()
|
||||
self.create_and_test_config_to_json_string()
|
||||
self.create_and_test_config_to_json_file()
|
||||
|
||||
|
||||
class GPTModelTester(object):
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_position_ids=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
n_special=1,
|
||||
n_positions=33,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
n_choices=3,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
config_class=None,
|
||||
base_model_class=None,
|
||||
lm_head_model_class=None,
|
||||
double_head_model_class=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_position_ids = use_position_ids
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.n_special = n_special
|
||||
self.n_positions = n_positions
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.n_choices = n_choices
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.scope = scope
|
||||
self.config_class = config_class
|
||||
self.base_model_class = base_model_class
|
||||
self.lm_head_model_class = lm_head_model_class
|
||||
self.double_head_model_class = double_head_model_class
|
||||
self.all_model_classes = (base_model_class, lm_head_model_class, double_head_model_class)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
total_num_tokens = self.vocab_size + self.n_special
|
||||
input_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)
|
||||
|
||||
position_ids = None
|
||||
if self.use_position_ids:
|
||||
position_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
total_voc = self.vocab_size
|
||||
token_type_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
|
||||
|
||||
mc_labels = None
|
||||
lm_labels = None
|
||||
mc_token_ids = None
|
||||
if self.use_labels:
|
||||
mc_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
lm_labels = ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
|
||||
mc_token_ids = ids_tensor([self.batch_size, self.n_choices], self.seq_length)
|
||||
|
||||
config = self.config_class(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
n_special=self.n_special,
|
||||
n_positions=self.n_positions,
|
||||
n_embd=self.hidden_size,
|
||||
n_layer=self.num_hidden_layers,
|
||||
n_head=self.num_attention_heads,
|
||||
initializer_range=self.initializer_range)
|
||||
|
||||
return (config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_ids)
|
||||
|
||||
def create_and_check_base_model(self, config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_ids):
|
||||
model = self.base_model_class(config)
|
||||
model.eval()
|
||||
|
||||
outputs = model(input_ids, position_ids, token_type_ids)
|
||||
outputs = model(input_ids, position_ids)
|
||||
outputs = model(input_ids)
|
||||
|
||||
hidden_state = outputs[0]
|
||||
self.parent.assertListEqual(
|
||||
list(hidden_state.size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, self.hidden_size])
|
||||
|
||||
|
||||
def create_and_check_lm_head(self, config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_ids):
|
||||
model = self.lm_head_model_class(config)
|
||||
model.eval()
|
||||
outputs = model(input_ids, position_ids, token_type_ids, lm_labels)
|
||||
loss, lm_logits = outputs[:2]
|
||||
|
||||
total_voc = self.n_special + self.vocab_size
|
||||
self.parent.assertListEqual(
|
||||
list(lm_logits.size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
self.parent.assertListEqual(
|
||||
list(loss.size()),
|
||||
[])
|
||||
|
||||
def create_and_check_presents(self, config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_ids):
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.eval()
|
||||
outputs = model(input_ids)
|
||||
presents = outputs[-1]
|
||||
self.parent.assertEqual(self.num_hidden_layers, len(presents))
|
||||
self.parent.assertListEqual(
|
||||
list(presents[0].size()),
|
||||
[2, self.batch_size * self.n_choices, self.num_attention_heads,
|
||||
self.seq_length, self.hidden_size // self.num_attention_heads])
|
||||
|
||||
def create_and_check_double_heads(self, config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_ids):
|
||||
model = self.double_head_model_class(config)
|
||||
model.eval()
|
||||
outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels,
|
||||
token_type_ids=token_type_ids, position_ids=position_ids)
|
||||
lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4]
|
||||
loss = [lm_loss, mc_loss]
|
||||
|
||||
total_voc = self.n_special + self.vocab_size
|
||||
self.parent.assertListEqual(
|
||||
list(lm_logits.size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
self.parent.assertListEqual(
|
||||
list(mc_logits.size()),
|
||||
[self.batch_size, self.n_choices])
|
||||
self.parent.assertListEqual(
|
||||
[list(l.size()) for l in loss],
|
||||
[[], []])
|
||||
|
||||
def create_and_check_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.parent.assertIsNotNone(model)
|
||||
|
||||
def create_and_check_commons(self, config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_ids):
|
||||
inputs_dict = {'input_ids': input_ids}
|
||||
create_and_check_commons(self, config, inputs_dict)
|
||||
|
||||
def run_common_tests(self, test_presents=False):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
self.create_and_check_base_model(*config_and_inputs)
|
||||
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
self.create_and_check_lm_head(*config_and_inputs)
|
||||
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
self.create_and_check_double_heads(*config_and_inputs)
|
||||
|
||||
if test_presents:
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
self.create_and_check_presents(*config_and_inputs)
|
||||
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
self.create_and_check_commons(*config_and_inputs)
|
||||
|
||||
def run_slow_tests(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
self.create_and_check_model_from_pretrained(*config_and_inputs)
|
||||
|
||||
50
pytorch_transformers/tests/model_utils_test.py
Normal file
50
pytorch_transformers/tests/model_utils_test.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 HuggingFace Inc..
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import PretrainedConfig, PreTrainedModel
|
||||
from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
config = BertConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, PretrainedConfig)
|
||||
|
||||
model = BertModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, PreTrainedModel)
|
||||
|
||||
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||
self.assertEqual(model.config.output_attentions, True)
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(model.config, config)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
304
pytorch_transformers/tests/modeling_bert_test.py
Normal file
304
pytorch_transformers/tests/modeling_bert_test.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
|
||||
BertForNextSentencePrediction, BertForPreTraining,
|
||||
BertForQuestionAnswering, BertForSequenceClassification,
|
||||
BertForTokenClassification, BertForMultipleChoice)
|
||||
from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
|
||||
|
||||
|
||||
class BertModelTest(unittest.TestCase):
|
||||
class BertModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
|
||||
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
|
||||
BertForTokenClassification),
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.all_model_classes = all_model_classes
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = BertConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
|
||||
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = BertModel(config=config)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
|
||||
def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = BertForMaskedLM(config=config)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = BertForNextSentencePrediction(config=config)
|
||||
model.eval()
|
||||
loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"seq_relationship_score": seq_relationship_score,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["seq_relationship_score"].size()),
|
||||
[self.batch_size, 2])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = BertForPreTraining(config=config)
|
||||
model.eval()
|
||||
loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
"seq_relationship_score": seq_relationship_score,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(result["seq_relationship_score"].size()),
|
||||
[self.batch_size, 2])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = BertForQuestionAnswering(config=config)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = BertForSequenceClassification(config)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = BertForTokenClassification(config=config)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, token_type_ids, input_mask, token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
config.num_choices = self.num_choices
|
||||
model = BertForMultipleChoice(config=config)
|
||||
model.eval()
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(multiple_choice_inputs_ids,
|
||||
multiple_choice_token_type_ids,
|
||||
multiple_choice_input_mask,
|
||||
choice_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.num_choices])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_bert_commons(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
|
||||
create_and_check_commons(self, config, inputs_dict)
|
||||
|
||||
def test_default(self):
|
||||
self.run_tester(BertModelTest.BertModelTester(self))
|
||||
|
||||
def test_config(self):
|
||||
config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
|
||||
config_tester.run_common_tests()
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def run_tester(self, tester):
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_bert_model(*config_and_inputs)
|
||||
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
|
||||
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
|
||||
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_bert_for_pretraining(*config_and_inputs)
|
||||
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_bert_for_question_answering(*config_and_inputs)
|
||||
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_bert_for_token_classification(*config_and_inputs)
|
||||
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_bert_commons(*config_and_inputs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
53
pytorch_transformers/tests/modeling_gpt2_test.py
Normal file
53
pytorch_transformers/tests/modeling_gpt2_test.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (GPT2Config, GPT2Model,
|
||||
GPT2LMHeadModel, GPT2DoubleHeadsModel)
|
||||
|
||||
from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
|
||||
|
||||
class GPT2ModelTest(unittest.TestCase):
|
||||
|
||||
def test_config(self):
|
||||
config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37)
|
||||
config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
|
||||
lm_head_model_class=GPT2LMHeadModel,
|
||||
double_head_model_class=GPT2DoubleHeadsModel)
|
||||
model_tester.run_common_tests(test_presents=True)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_pretrained(self):
|
||||
model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
|
||||
lm_head_model_class=GPT2LMHeadModel,
|
||||
double_head_model_class=GPT2DoubleHeadsModel)
|
||||
model_tester.run_slow_tests()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
49
pytorch_transformers/tests/modeling_openai_test.py
Normal file
49
pytorch_transformers/tests/modeling_openai_test.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel,
|
||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
|
||||
|
||||
from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
|
||||
|
||||
class OpenAIModelTest(unittest.TestCase):
|
||||
|
||||
def test_config(self):
|
||||
config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37)
|
||||
config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
model_tester = GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
|
||||
lm_head_model_class=OpenAIGPTLMHeadModel,
|
||||
double_head_model_class=OpenAIGPTDoubleHeadsModel)
|
||||
model_tester.run_common_tests(test_presents=False)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_pretrained(self):
|
||||
model_tester = GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
|
||||
lm_head_model_class=OpenAIGPTLMHeadModel,
|
||||
double_head_model_class=OpenAIGPTDoubleHeadsModel)
|
||||
model_tester.run_slow_tests()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
212
pytorch_transformers/tests/modeling_transfo_xl_test.py
Normal file
212
pytorch_transformers/tests/modeling_transfo_xl_test.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
|
||||
from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
|
||||
|
||||
class TransfoXLModelTest(unittest.TestCase):
|
||||
class TransfoXLModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
mem_len=30,
|
||||
clamp_len=15,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
cutoffs=[10, 50, 80],
|
||||
hidden_size=32,
|
||||
d_embed=32,
|
||||
num_attention_heads=4,
|
||||
d_head=8,
|
||||
d_inner=128,
|
||||
div_val=2,
|
||||
num_hidden_layers=5,
|
||||
scope=None,
|
||||
seed=1,
|
||||
all_model_classes=(TransfoXLModel, TransfoXLLMHeadModel),
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.mem_len = mem_len
|
||||
self.key_len = seq_length + mem_len
|
||||
self.clamp_len = clamp_len
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.cutoffs = cutoffs
|
||||
self.hidden_size = hidden_size
|
||||
self.d_embed = d_embed
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.d_head = d_head
|
||||
self.d_inner = d_inner
|
||||
self.div_val = div_val
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.scope = scope
|
||||
self.seed = seed
|
||||
self.all_model_classes = all_model_classes
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
lm_labels = None
|
||||
if self.use_labels:
|
||||
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
config = TransfoXLConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
mem_len=self.mem_len,
|
||||
clamp_len=self.clamp_len,
|
||||
cutoffs=self.cutoffs,
|
||||
d_model=self.hidden_size,
|
||||
d_embed=self.d_embed,
|
||||
n_head=self.num_attention_heads,
|
||||
d_head=self.d_head,
|
||||
d_inner=self.d_inner,
|
||||
div_val=self.div_val,
|
||||
n_layer=self.num_hidden_layers)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||
|
||||
def set_seed(self):
|
||||
random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||
model = TransfoXLModel(config)
|
||||
model.eval()
|
||||
|
||||
hidden_states_1, mems_1 = model(input_ids_1)
|
||||
hidden_states_2, mems_2 = model(input_ids_2, mems_1)
|
||||
outputs = {
|
||||
"hidden_states_1": hidden_states_1,
|
||||
"mems_1": mems_1,
|
||||
"hidden_states_2": hidden_states_2,
|
||||
"mems_2": mems_2,
|
||||
}
|
||||
return outputs
|
||||
|
||||
def check_transfo_xl_model_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_1"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_2"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
|
||||
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||
model = TransfoXLLMHeadModel(config)
|
||||
model.eval()
|
||||
|
||||
lm_logits_1, mems_1 = model(input_ids_1)
|
||||
loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)
|
||||
lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
|
||||
loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)
|
||||
|
||||
outputs = {
|
||||
"loss_1": loss_1,
|
||||
"mems_1": mems_1,
|
||||
"lm_logits_1": lm_logits_1,
|
||||
"loss_2": loss_2,
|
||||
"mems_2": mems_2,
|
||||
"lm_logits_2": lm_logits_2,
|
||||
}
|
||||
return outputs
|
||||
|
||||
def check_transfo_xl_lm_head_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss_1"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_1"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss_2"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_2"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||
inputs_dict = {'input_ids': input_ids_1}
|
||||
create_and_check_commons(self, config, inputs_dict, test_pruning=False, test_torchscript=False)
|
||||
|
||||
def test_default(self):
|
||||
self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
|
||||
|
||||
def test_config(self):
|
||||
config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
|
||||
config_tester.run_common_tests()
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def run_tester(self, tester):
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
|
||||
tester.set_seed()
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
output_result = tester.create_transfo_xl_model(*config_and_inputs)
|
||||
tester.check_transfo_xl_model_output(output_result)
|
||||
|
||||
tester.set_seed()
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
|
||||
tester.check_transfo_xl_lm_head_output(output_result)
|
||||
|
||||
tester.set_seed()
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_transfo_xl_commons(*config_and_inputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
282
pytorch_transformers/tests/modeling_xlm_test.py
Normal file
282
pytorch_transformers/tests/modeling_xlm_test.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
|
||||
from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
|
||||
|
||||
|
||||
class XLMModelTest(unittest.TestCase):
|
||||
class XLMModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_lengths=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
gelu_activation=True,
|
||||
sinusoidal_embeddings=False,
|
||||
causal=False,
|
||||
asm=False,
|
||||
n_langs=2,
|
||||
vocab_size=99,
|
||||
n_special=0,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
summary_type="last",
|
||||
use_proj=True,
|
||||
scope=None,
|
||||
all_model_classes = (XLMModel, XLMWithLMHeadModel,
|
||||
XLMForQuestionAnswering, XLMForSequenceClassification), # , XLMForSequenceClassification, XLMForTokenClassification),
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_lengths = use_input_lengths
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.gelu_activation = gelu_activation
|
||||
self.sinusoidal_embeddings = sinusoidal_embeddings
|
||||
self.asm = asm
|
||||
self.n_langs = n_langs
|
||||
self.vocab_size = vocab_size
|
||||
self.n_special = n_special
|
||||
self.summary_type = summary_type
|
||||
self.causal = causal
|
||||
self.use_proj = use_proj
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.n_langs = n_langs
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.summary_type = summary_type
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.all_model_classes = all_model_classes
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
|
||||
|
||||
input_lengths = None
|
||||
if self.use_input_lengths:
|
||||
input_lengths = ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2 # small variation of seq_length
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.n_langs)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
is_impossible_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
|
||||
|
||||
config = XLMConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
n_special=self.n_special,
|
||||
emb_dim=self.hidden_size,
|
||||
n_layers=self.num_hidden_layers,
|
||||
n_heads=self.num_attention_heads,
|
||||
dropout=self.hidden_dropout_prob,
|
||||
attention_dropout=self.attention_probs_dropout_prob,
|
||||
gelu_activation=self.gelu_activation,
|
||||
sinusoidal_embeddings=self.sinusoidal_embeddings,
|
||||
asm=self.asm,
|
||||
causal=self.causal,
|
||||
n_langs=self.n_langs,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
initializer_range=self.initializer_range,
|
||||
summary_type=self.summary_type,
|
||||
use_proj=self.use_proj)
|
||||
|
||||
return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
|
||||
def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
|
||||
model = XLMModel(config=config)
|
||||
model.eval()
|
||||
outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
|
||||
outputs = model(input_ids, langs=token_type_ids)
|
||||
outputs = model(input_ids)
|
||||
sequence_output = outputs[0]
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
|
||||
|
||||
def create_and_check_xlm_lm_head(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
|
||||
model = XLMWithLMHeadModel(config)
|
||||
model.eval()
|
||||
|
||||
loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
|
||||
def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
|
||||
model = XLMForQuestionAnswering(config)
|
||||
model.eval()
|
||||
|
||||
outputs = model(input_ids)
|
||||
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
|
||||
|
||||
outputs = model(input_ids, start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
cls_index=sequence_labels,
|
||||
is_impossible=is_impossible_labels,
|
||||
p_mask=input_mask)
|
||||
|
||||
outputs = model(input_ids, start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
cls_index=sequence_labels,
|
||||
is_impossible=is_impossible_labels)
|
||||
|
||||
total_loss, start_logits, end_logits, cls_logits = outputs
|
||||
|
||||
outputs = model(input_ids, start_positions=sequence_labels,
|
||||
end_positions=sequence_labels)
|
||||
|
||||
total_loss, start_logits, end_logits = outputs
|
||||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
"cls_logits": cls_logits,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["cls_logits"].size()),
|
||||
[self.batch_size])
|
||||
|
||||
|
||||
def create_and_check_xlm_sequence_classif(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
|
||||
model = XLMForSequenceClassification(config)
|
||||
model.eval()
|
||||
|
||||
(logits,) = model(input_ids)
|
||||
loss, logits = model(input_ids, labels=sequence_labels)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.type_sequence_label_size])
|
||||
|
||||
|
||||
def create_and_check_xlm_commons(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
|
||||
create_and_check_commons(self, config, inputs_dict)
|
||||
|
||||
def test_default(self):
|
||||
self.run_tester(XLMModelTest.XLMModelTester(self))
|
||||
|
||||
def test_config(self):
|
||||
config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)
|
||||
config_tester.run_common_tests()
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def run_tester(self, tester):
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_xlm_model(*config_and_inputs)
|
||||
|
||||
# config_and_inputs = tester.prepare_config_and_inputs()
|
||||
# tester.create_and_check_xlm_for_masked_lm(*config_and_inputs)
|
||||
|
||||
# config_and_inputs = tester.prepare_config_and_inputs()
|
||||
# tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
# config_and_inputs = tester.prepare_config_and_inputs()
|
||||
# tester.create_and_check_xlm_for_question_answering(*config_and_inputs)
|
||||
|
||||
# config_and_inputs = tester.prepare_config_and_inputs()
|
||||
# tester.create_and_check_xlm_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
# config_and_inputs = tester.prepare_config_and_inputs()
|
||||
# tester.create_and_check_xlm_for_token_classification(*config_and_inputs)
|
||||
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_xlm_commons(*config_and_inputs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
310
pytorch_transformers/tests/modeling_xlnet_test.py
Normal file
310
pytorch_transformers/tests/modeling_xlnet_test.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
|
||||
from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
|
||||
|
||||
class XLNetModelTest(unittest.TestCase):
|
||||
class XLNetModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
mem_len=10,
|
||||
clamp_len=-1,
|
||||
reuse_len=15,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
cutoffs=[10, 50, 80],
|
||||
hidden_size=32,
|
||||
num_attention_heads=4,
|
||||
d_inner=128,
|
||||
num_hidden_layers=5,
|
||||
max_position_embeddings=10,
|
||||
type_sequence_label_size=2,
|
||||
untie_r=True,
|
||||
bi_data=False,
|
||||
same_length=False,
|
||||
initializer_range=0.05,
|
||||
seed=1,
|
||||
type_vocab_size=2,
|
||||
all_model_classes=(XLNetModel, XLNetLMHeadModel,
|
||||
XLNetForSequenceClassification, XLNetForQuestionAnswering),
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.mem_len = mem_len
|
||||
# self.key_len = seq_length + mem_len
|
||||
self.clamp_len = clamp_len
|
||||
self.reuse_len = reuse_len
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.cutoffs = cutoffs
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.d_inner = d_inner
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.bi_data = bi_data
|
||||
self.untie_r = untie_r
|
||||
self.same_length = same_length
|
||||
self.initializer_range = initializer_range
|
||||
self.seed = seed
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.all_model_classes = all_model_classes
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()
|
||||
|
||||
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
|
||||
perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
|
||||
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
||||
target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
|
||||
target_mapping[:, 0, -1] = 1.0 # predict last token
|
||||
inp_q = target_mapping[:, 0, :].clone() # predict last token
|
||||
|
||||
sequence_labels = None
|
||||
lm_labels = None
|
||||
is_impossible_labels = None
|
||||
if self.use_labels:
|
||||
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
|
||||
|
||||
config = XLNetConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
d_model=self.hidden_size,
|
||||
n_head=self.num_attention_heads,
|
||||
d_inner=self.d_inner,
|
||||
n_layer=self.num_hidden_layers,
|
||||
untie_r=self.untie_r,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
mem_len=self.mem_len,
|
||||
clamp_len=self.clamp_len,
|
||||
same_length=self.same_length,
|
||||
reuse_len=self.reuse_len,
|
||||
bi_data=self.bi_data,
|
||||
initializer_range=self.initializer_range,
|
||||
num_labels=self.type_sequence_label_size)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
|
||||
|
||||
def set_seed(self):
|
||||
random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
model = XLNetModel(config)
|
||||
model.eval()
|
||||
|
||||
_, _ = model(input_ids_1, input_mask=input_mask)
|
||||
_, _ = model(input_ids_1, attention_mask=input_mask)
|
||||
_, _ = model(input_ids_1, token_type_ids=segment_ids)
|
||||
outputs, mems_1 = model(input_ids_1)
|
||||
|
||||
result = {
|
||||
"mems_1": mems_1,
|
||||
"outputs": outputs,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["outputs"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
model = XLNetLMHeadModel(config)
|
||||
model.eval()
|
||||
|
||||
loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
|
||||
|
||||
loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
|
||||
|
||||
logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping, inp_q=inp_q)
|
||||
|
||||
result = {
|
||||
"loss_1": loss_1,
|
||||
"mems_1": mems_1,
|
||||
"all_logits_1": all_logits_1,
|
||||
"loss_2": loss_2,
|
||||
"mems_2": mems_2,
|
||||
"all_logits_2": all_logits_2,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss_1"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["all_logits_1"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss_2"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["all_logits_2"].size()),
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
model = XLNetForQuestionAnswering(config)
|
||||
model.eval()
|
||||
|
||||
outputs = model(input_ids_1)
|
||||
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
|
||||
|
||||
outputs = model(input_ids_1, start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
cls_index=sequence_labels,
|
||||
is_impossible=is_impossible_labels,
|
||||
p_mask=input_mask)
|
||||
|
||||
outputs = model(input_ids_1, start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
cls_index=sequence_labels,
|
||||
is_impossible=is_impossible_labels)
|
||||
|
||||
total_loss, start_logits, end_logits, cls_logits, mems = outputs
|
||||
|
||||
outputs = model(input_ids_1, start_positions=sequence_labels,
|
||||
end_positions=sequence_labels)
|
||||
|
||||
total_loss, start_logits, end_logits, mems = outputs
|
||||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
"cls_logits": cls_logits,
|
||||
"mems": mems,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["cls_logits"].size()),
|
||||
[self.batch_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
model = XLNetForSequenceClassification(config)
|
||||
model.eval()
|
||||
|
||||
logits, mems_1 = model(input_ids_1)
|
||||
loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"mems_1": mems_1,
|
||||
"logits": logits,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.type_sequence_label_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, inp_q, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
inputs_dict = {'input_ids': input_ids_1}
|
||||
create_and_check_commons(self, config, inputs_dict, test_pruning=False)
|
||||
|
||||
def test_default(self):
|
||||
self.run_tester(XLNetModelTest.XLNetModelTester(self))
|
||||
|
||||
def test_config(self):
|
||||
config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
|
||||
config_tester.run_common_tests()
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_model_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def run_tester(self, tester):
|
||||
tester.set_seed()
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_xlnet_base_model(*config_and_inputs)
|
||||
|
||||
tester.set_seed()
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_xlnet_lm_head(*config_and_inputs)
|
||||
|
||||
tester.set_seed()
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
|
||||
|
||||
tester.set_seed()
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_xlnet_qa(*config_and_inputs)
|
||||
|
||||
tester.set_seed()
|
||||
config_and_inputs = tester.prepare_config_and_inputs()
|
||||
tester.create_and_check_xlnet_commons(*config_and_inputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
91
pytorch_transformers/tests/optimization_test.py
Normal file
91
pytorch_transformers/tests/optimization_test.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_transformers import BertAdam
|
||||
from pytorch_transformers import OpenAIAdam
|
||||
from pytorch_transformers.optimization import ConstantLR, WarmupLinearSchedule, WarmupConstantSchedule, \
|
||||
WarmupCosineWithWarmupRestartsSchedule, WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule
|
||||
import numpy as np
|
||||
|
||||
|
||||
class OptimizationTest(unittest.TestCase):
|
||||
|
||||
def assertListAlmostEqual(self, list1, list2, tol):
|
||||
self.assertEqual(len(list1), len(list2))
|
||||
for a, b in zip(list1, list2):
|
||||
self.assertAlmostEqual(a, b, delta=tol)
|
||||
|
||||
def test_adam(self):
|
||||
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
|
||||
target = torch.tensor([0.4, 0.2, -0.5])
|
||||
criterion = torch.nn.MSELoss()
|
||||
# No warmup, constant schedule, no gradient clipping
|
||||
optimizer = BertAdam(params=[w], lr=2e-1,
|
||||
weight_decay=0.0,
|
||||
max_grad_norm=-1)
|
||||
for _ in range(100):
|
||||
loss = criterion(w, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves.
|
||||
w.grad.zero_()
|
||||
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
|
||||
|
||||
|
||||
class ScheduleInitTest(unittest.TestCase):
|
||||
def test_bert_sched_init(self):
|
||||
m = torch.nn.Linear(50, 50)
|
||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
||||
# shouldn't fail
|
||||
|
||||
def test_openai_sched_init(self):
|
||||
m = torch.nn.Linear(50, 50)
|
||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
|
||||
optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
|
||||
self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
|
||||
# shouldn't fail
|
||||
|
||||
|
||||
class WarmupCosineWithRestartsTest(unittest.TestCase):
|
||||
def test_it(self):
|
||||
m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
|
||||
x = np.arange(0, 1000)
|
||||
y = [m.get_lr(xe) for xe in x]
|
||||
y = np.asarray(y)
|
||||
expected_zeros = y[[0, 200, 400, 600, 800]]
|
||||
print(expected_zeros)
|
||||
expected_ones = y[[50, 250, 450, 650, 850]]
|
||||
print(expected_ones)
|
||||
self.assertTrue(np.allclose(expected_ones, 1))
|
||||
self.assertTrue(np.allclose(expected_zeros, 0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
131
pytorch_transformers/tests/tokenization_bert_test.py
Normal file
131
pytorch_transformers/tests/tokenization_bert_test.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from io import open
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers.tokenization_bert import (BasicTokenizer,
|
||||
BertTokenizer,
|
||||
WordpieceTokenizer,
|
||||
_is_control, _is_punctuation,
|
||||
_is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)
|
||||
|
||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
class TokenizationTest(unittest.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", ","
|
||||
]
|
||||
with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
vocab_file = vocab_writer.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
|
||||
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
|
||||
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
|
||||
|
||||
os.remove(vocab_file)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_tokenizer_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
|
||||
tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(tokenizer)
|
||||
|
||||
def test_chinese(self):
|
||||
tokenizer = BasicTokenizer()
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
|
||||
[u"ah", u"\u535A", u"\u63A8", u"zz"])
|
||||
|
||||
def test_basic_tokenizer_lower(self):
|
||||
tokenizer = BasicTokenizer(do_lower_case=True)
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||
["hello", "!", "how", "are", "you", "?"])
|
||||
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
|
||||
|
||||
def test_basic_tokenizer_no_lower(self):
|
||||
tokenizer = BasicTokenizer(do_lower_case=False)
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
|
||||
["HeLLo", "!", "how", "Are", "yoU", "?"])
|
||||
|
||||
def test_wordpiece_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing"
|
||||
]
|
||||
|
||||
vocab = {}
|
||||
for (i, token) in enumerate(vocab_tokens):
|
||||
vocab[token] = i
|
||||
tokenizer = WordpieceTokenizer(vocab=vocab)
|
||||
|
||||
self.assertListEqual(tokenizer.tokenize(""), [])
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.tokenize("unwanted running"),
|
||||
["un", "##want", "##ed", "runn", "##ing"])
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
|
||||
|
||||
def test_is_whitespace(self):
|
||||
self.assertTrue(_is_whitespace(u" "))
|
||||
self.assertTrue(_is_whitespace(u"\t"))
|
||||
self.assertTrue(_is_whitespace(u"\r"))
|
||||
self.assertTrue(_is_whitespace(u"\n"))
|
||||
self.assertTrue(_is_whitespace(u"\u00A0"))
|
||||
|
||||
self.assertFalse(_is_whitespace(u"A"))
|
||||
self.assertFalse(_is_whitespace(u"-"))
|
||||
|
||||
def test_is_control(self):
|
||||
self.assertTrue(_is_control(u"\u0005"))
|
||||
|
||||
self.assertFalse(_is_control(u"A"))
|
||||
self.assertFalse(_is_control(u" "))
|
||||
self.assertFalse(_is_control(u"\t"))
|
||||
self.assertFalse(_is_control(u"\r"))
|
||||
|
||||
def test_is_punctuation(self):
|
||||
self.assertTrue(_is_punctuation(u"-"))
|
||||
self.assertTrue(_is_punctuation(u"$"))
|
||||
self.assertTrue(_is_punctuation(u"`"))
|
||||
self.assertTrue(_is_punctuation(u"."))
|
||||
|
||||
self.assertFalse(_is_punctuation(u"A"))
|
||||
self.assertFalse(_is_punctuation(u" "))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
68
pytorch_transformers/tests/tokenization_gpt2_test.py
Normal file
68
pytorch_transformers/tests/tokenization_gpt2_test.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
|
||||
|
||||
from .tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
class GPT2TokenizationTest(unittest.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||
"lo", "low", "er",
|
||||
"low", "lowest", "newer", "wider"]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
|
||||
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
vocab_file = fp.name
|
||||
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
merges_file = fp.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, GPT2Tokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
|
||||
tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [13, 12, 16]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
|
||||
# @pytest.mark.slow
|
||||
def test_tokenizer_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(tokenizer)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
71
pytorch_transformers/tests/tokenization_openai_test.py
Normal file
71
pytorch_transformers/tests/tokenization_openai_test.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
|
||||
|
||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
|
||||
class OpenAIGPTTokenizationTest(unittest.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||
"w</w>", "r</w>", "t</w>",
|
||||
"lo", "low", "er</w>",
|
||||
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
|
||||
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
vocab_file = fp.name
|
||||
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
merges_file = fp.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er</w>"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [14, 15, 20]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_tokenizer_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
|
||||
tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(tokenizer)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
81
pytorch_transformers/tests/tokenization_tests_commons.py
Normal file
81
pytorch_transformers/tests/tokenization_tests_commons.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
if sys.version_info[0] == 3:
|
||||
unicode = str
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
|
||||
|
||||
def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class(*inputs, **kwargs)
|
||||
|
||||
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||
|
||||
vocab_path="/tmp/"
|
||||
output_files = tokenizer.save_vocabulary(vocab_path=vocab_path)
|
||||
tokenizer = tokenizer.from_pretrained(vocab_path)
|
||||
|
||||
for f in output_files:
|
||||
os.remove(f)
|
||||
|
||||
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
|
||||
tester.assertListEqual(before_tokens, after_tokens)
|
||||
|
||||
def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class(*inputs, **kwargs)
|
||||
|
||||
text = "Munich and Berlin are nice cities"
|
||||
filename = u"/tmp/tokenizer.bin"
|
||||
|
||||
subwords = tokenizer.tokenize(text)
|
||||
|
||||
pickle.dump(tokenizer, open(filename, "wb"))
|
||||
|
||||
tokenizer_new = pickle.load(open(filename, "rb"))
|
||||
subwords_loaded = tokenizer_new.tokenize(text)
|
||||
|
||||
tester.assertListEqual(subwords, subwords_loaded)
|
||||
|
||||
|
||||
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class(*inputs, **kwargs)
|
||||
|
||||
text = u"He is very happy, UNwant\u00E9d,running"
|
||||
tokens = tokenizer.tokenize(text)
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
ids_2 = tokenizer.encode(text)
|
||||
tester.assertListEqual(ids, ids_2)
|
||||
|
||||
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
|
||||
text_2 = tokenizer.decode(ids)
|
||||
|
||||
tester.assertNotEqual(len(tokens_2), 0)
|
||||
tester.assertIsInstance(text_2, (str, unicode))
|
||||
|
||||
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
|
||||
create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
71
pytorch_transformers/tests/tokenization_transfo_xl_test.py
Normal file
71
pytorch_transformers/tests/tokenization_transfo_xl_test.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from io import open
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
|
||||
|
||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
class TransfoXLTokenizationTest(unittest.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
vocab_tokens = [
|
||||
"<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ","
|
||||
]
|
||||
with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
vocab_file = vocab_writer.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, vocab_file=vocab_file, lower_case=True)
|
||||
|
||||
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
|
||||
os.remove(vocab_file)
|
||||
|
||||
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
|
||||
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
|
||||
|
||||
def test_full_tokenizer_lower(self):
|
||||
tokenizer = TransfoXLTokenizer(lower_case=True)
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
|
||||
["hello", "!", "how", "are", "you", "?"])
|
||||
|
||||
def test_full_tokenizer_no_lower(self):
|
||||
tokenizer = TransfoXLTokenizer(lower_case=False)
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
|
||||
["HeLLo", "!", "how", "Are", "yoU", "?"])
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_tokenizer_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
|
||||
tokenizer = TransfoXLTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(tokenizer)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
70
pytorch_transformers/tests/tokenization_xlm_test.py
Normal file
70
pytorch_transformers/tests/tokenization_xlm_test.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import json
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers.tokenization_xlm import XLMTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
|
||||
|
||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
class XLMTokenizationTest(unittest.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
""" Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
|
||||
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
|
||||
"w</w>", "r</w>", "t</w>",
|
||||
"lo", "low", "er</w>",
|
||||
"low</w>", "lowest</w>", "newer</w>", "wider</w>"]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
|
||||
with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
|
||||
fp.write(json.dumps(vocab_tokens))
|
||||
vocab_file = fp.name
|
||||
with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
merges_file = fp.name
|
||||
|
||||
create_and_check_tokenizer_commons(self, XLMTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
|
||||
tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
|
||||
os.remove(vocab_file)
|
||||
os.remove(merges_file)
|
||||
|
||||
text = "lower"
|
||||
bpe_tokens = ["low", "er</w>"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = tokens + ["<unk>"]
|
||||
input_bpe_tokens = [14, 15, 20]
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_tokenizer_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
|
||||
tokenizer = XLMTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(tokenizer)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
90
pytorch_transformers/tests/tokenization_xlnet_test.py
Normal file
90
pytorch_transformers/tests/tokenization_xlnet_test.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer,
|
||||
PRETRAINED_VOCAB_ARCHIVE_MAP,
|
||||
SPIECE_UNDERLINE)
|
||||
|
||||
from.tokenization_tests_commons import create_and_check_tokenizer_commons
|
||||
|
||||
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
'fixtures/test_sentencepiece.model')
|
||||
|
||||
class XLNetTokenizationTest(unittest.TestCase):
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
create_and_check_tokenizer_commons(self, XLNetTokenizer, SAMPLE_VOCAB)
|
||||
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||
|
||||
tokens = tokenizer.tokenize(u'This is a test')
|
||||
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
|
||||
|
||||
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
|
||||
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
|
||||
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
|
||||
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
self.assertListEqual(
|
||||
ids, [8, 21, 84, 55, 24, 19, 7, 0,
|
||||
602, 347, 347, 347, 3, 12, 66,
|
||||
46, 72, 80, 6, 0, 4])
|
||||
|
||||
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||
self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
|
||||
u'or', u'n', SPIECE_UNDERLINE + u'in',
|
||||
SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
|
||||
SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
|
||||
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
|
||||
u'<unk>', u'.'])
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_tokenizer_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
|
||||
tokenizer = XLNetTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
shutil.rmtree(cache_dir)
|
||||
self.assertIsNotNone(tokenizer)
|
||||
|
||||
def test_tokenizer_lower(self):
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
|
||||
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
|
||||
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
|
||||
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
|
||||
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
|
||||
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"])
|
||||
|
||||
def test_tokenizer_no_lower(self):
|
||||
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
|
||||
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or',
|
||||
u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
|
||||
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
|
||||
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user