Blenderbot (#7418)
Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -212,8 +212,9 @@ class AutoModelTest(unittest.TestCase):
|
||||
mapping = tuple(mapping.items())
|
||||
for index, (child_config, child_model) in enumerate(mapping[1:]):
|
||||
for parent_config, parent_model in mapping[: index + 1]:
|
||||
with self.subTest(
|
||||
msg="Testing if {} is child of {}".format(child_config.__name__, parent_config.__name__)
|
||||
):
|
||||
self.assertFalse(issubclass(child_config, parent_config))
|
||||
self.assertFalse(issubclass(child_model, parent_model))
|
||||
assert not issubclass(
|
||||
child_config, parent_config
|
||||
), "{child_config.__name__} is child of {parent_config.__name__}"
|
||||
assert not issubclass(
|
||||
child_model, parent_model
|
||||
), "{child_config.__name__} is child of {parent_config.__name__}"
|
||||
|
||||
@@ -40,6 +40,11 @@ if is_torch_available():
|
||||
BartModel,
|
||||
BartTokenizer,
|
||||
BartTokenizerFast,
|
||||
BertConfig,
|
||||
BlenderbotConfig,
|
||||
MarianConfig,
|
||||
MBartConfig,
|
||||
PegasusConfig,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.modeling_bart import (
|
||||
@@ -175,7 +180,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
decoder_features_with_passed_mask = model(
|
||||
decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict
|
||||
)[0]
|
||||
_assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
|
||||
assert_tensors_close(decoder_features_with_passed_mask, decoder_features_with_created_mask)
|
||||
useless_mask = torch.zeros_like(decoder_attn_mask)
|
||||
decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
|
||||
self.assertTrue(isinstance(decoder_features, torch.Tensor)) # no hidden states or attentions
|
||||
@@ -189,7 +194,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
decoder_features_with_long_encoder_mask = model(
|
||||
inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
|
||||
)[0]
|
||||
_assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
|
||||
assert_tensors_close(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
|
||||
|
||||
def test_save_load_strict(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -364,7 +369,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
]
|
||||
for ex, desired_result in zip(examples, fairseq_results):
|
||||
bart_toks = tokenizer.encode(ex, return_tensors="pt")
|
||||
_assert_tensors_equal(desired_result.long(), bart_toks, prefix=ex)
|
||||
assert_tensors_close(desired_result.long(), bart_toks, prefix=ex)
|
||||
|
||||
def test_generate_fp16(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
@@ -411,16 +416,22 @@ class BartHeadTests(unittest.TestCase):
|
||||
self.assertTrue(torch.eq(input_new, output_new).all())
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b aren't both tensors, raise a nice Assertion error."""
|
||||
|
||||
if a is None and b is None:
|
||||
return True
|
||||
assert a.shape == b.shape
|
||||
try:
|
||||
if torch.allclose(a, b, atol=atol):
|
||||
return True
|
||||
raise
|
||||
except Exception:
|
||||
msg = "{} != {}".format(a, b)
|
||||
pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item()
|
||||
if a.numel() > 100:
|
||||
msg = f"tensor values are {pct_different:.1%} percent different."
|
||||
else:
|
||||
msg = f"{a} != {b}"
|
||||
if prefix:
|
||||
msg = prefix + ": " + msg
|
||||
raise AssertionError(msg)
|
||||
@@ -496,8 +507,8 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad)
|
||||
with torch.no_grad():
|
||||
logits2 = model(**inputs_dict)[0]
|
||||
_assert_tensors_equal(batched_logits[1], logits2, atol=TOLERANCE)
|
||||
_assert_tensors_equal(expected_slice, logits_arr, atol=TOLERANCE)
|
||||
assert_tensors_close(batched_logits[1], logits2, atol=TOLERANCE)
|
||||
assert_tensors_close(expected_slice, logits_arr, atol=TOLERANCE)
|
||||
|
||||
@slow
|
||||
def test_xsum_summarization_same_as_fairseq(self):
|
||||
@@ -633,3 +644,12 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
torch.tensor(self.desired_weights, device=torch_device), no_cache_pad_zero[:3, :5], atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
def test_child_config_equivalence(self):
|
||||
"""Test that configs associated with children of BartForConditionalGeneration are identical."""
|
||||
child_classes = [BlenderbotConfig, MBartConfig, MarianConfig, PegasusConfig]
|
||||
parent_keys = BartConfig().to_dict().keys()
|
||||
for c in child_classes:
|
||||
assert c().to_dict().keys() == parent_keys # traceback is very nice on it's own
|
||||
# check that test is not stupid
|
||||
assert BertConfig().to_dict().keys() != parent_keys
|
||||
|
||||
215
tests/test_modeling_blenderbot.py
Normal file
215
tests/test_modeling_blenderbot.py
Normal file
@@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python3
|
||||
# coding=utf-8
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the;
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Tests for BlenderBot"""
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BlenderbotConfig,
|
||||
BlenderbotForConditionalGeneration,
|
||||
BlenderbotSmallTokenizer,
|
||||
BlenderbotTokenizer,
|
||||
)
|
||||
|
||||
TOK_DECODE_KW = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
FASTER_GEN_KWARGS = dict(num_beams=1, early_stopping=True, min_length=15, max_length=25)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BlenderbotModelTester:
|
||||
# Required attributes
|
||||
vocab_size = 99
|
||||
batch_size = 13
|
||||
seq_length = 7
|
||||
num_hidden_layers = 2
|
||||
hidden_size = 16
|
||||
num_attention_heads = 4
|
||||
is_training = True
|
||||
|
||||
def __init__(self, parent):
|
||||
torch.manual_seed(0)
|
||||
self.parent = parent
|
||||
self.config = BlenderbotConfig(
|
||||
d_model=self.hidden_size,
|
||||
dropout=0.0,
|
||||
activation_function="gelu",
|
||||
vocab_size=self.vocab_size,
|
||||
encoder_layers=self.num_hidden_layers,
|
||||
decoder_layers=self.num_hidden_layers,
|
||||
encoder_attention_heads=self.num_attention_heads,
|
||||
decoder_attention_heads=self.num_attention_heads,
|
||||
attention_dropout=0.0,
|
||||
encoder_ffn_dim=4,
|
||||
decoder_ffn_dim=4,
|
||||
do_blenderbot_90_layernorm=False,
|
||||
normalize_before=True,
|
||||
max_position_embeddings=50,
|
||||
static_position_embeddings=False,
|
||||
scale_embedding=True,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
num_beams=1,
|
||||
min_length=3,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
return self.config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available():
|
||||
all_generative_model_classes = (BlenderbotForConditionalGeneration,)
|
||||
all_model_classes = (BlenderbotForConditionalGeneration,)
|
||||
else:
|
||||
all_generative_model_classes = ()
|
||||
all_model_classes = ()
|
||||
is_encoder_decoder = True
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
test_torchscript = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BlenderbotModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_initialization_module(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = BlenderbotForConditionalGeneration(config).model
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
enc_embeds = model.encoder.embed_tokens.weight
|
||||
assert (enc_embeds == model.shared.weight).all().item()
|
||||
self.assertAlmostEqual(torch.std(enc_embeds).item(), config.init_std, 2)
|
||||
|
||||
def test_embed_pos_shape(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = BlenderbotForConditionalGeneration(config)
|
||||
expected_shape = (config.max_position_embeddings + config.extra_pos_embeddings, config.d_model)
|
||||
assert model.model.encoder.embed_positions.weight.shape == expected_shape
|
||||
model.model.decoder.embed_positions.weight.shape == expected_shape
|
||||
|
||||
@unittest.skip("This test is flaky")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipUnless(torch_device != "cpu", "3B test too slow on CPU.")
|
||||
@require_torch
|
||||
class Blenderbot3BIntegrationTests(unittest.TestCase):
|
||||
ckpt = "facebook/blenderbot-3B"
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
model = BlenderbotForConditionalGeneration.from_pretrained(self.ckpt).to(torch_device)
|
||||
if torch_device == "cuda":
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
@cached_property
|
||||
def tokenizer(self):
|
||||
return BlenderbotTokenizer.from_pretrained(self.ckpt)
|
||||
|
||||
@slow
|
||||
def test_generation_from_short_input_same_as_parlai_3B(self):
|
||||
|
||||
src_text = ["Sam"]
|
||||
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
|
||||
generated_utterances = self.model.generate(**model_inputs, **FASTER_GEN_KWARGS)
|
||||
tgt_text = 'Sam is a great name. It means "sun" in Gaelic.'
|
||||
|
||||
generated_txt = self.tokenizer.batch_decode(generated_utterances, **TOK_DECODE_KW)
|
||||
assert generated_txt[0].strip() == tgt_text
|
||||
|
||||
@slow
|
||||
def test_generation_from_long_input_same_as_parlai_3B(self):
|
||||
|
||||
src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
|
||||
|
||||
model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device)
|
||||
generated_ids = self.model.generate(**model_inputs, **FASTER_GEN_KWARGS)[0]
|
||||
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
|
||||
|
||||
assert "I think it's because we are so worried about what people think of us." == reply.strip()
|
||||
|
||||
|
||||
@require_torch
|
||||
class Blenderbot90MIntegrationTests(unittest.TestCase):
|
||||
ckpt = "facebook/blenderbot-90M"
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(self.ckpt).to(torch_device)
|
||||
if torch_device == "cuda":
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
@cached_property
|
||||
def tokenizer(self):
|
||||
return AutoTokenizer.from_pretrained(self.ckpt)
|
||||
|
||||
@slow
|
||||
def test_90_generation_from_long_input(self):
|
||||
|
||||
src_text = [
|
||||
"Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like\
|
||||
i'm going to throw up.\nand why is that?"
|
||||
]
|
||||
|
||||
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
|
||||
assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
|
||||
assert self.model.config.do
|
||||
generated_ids = self.model.generate(**model_inputs)[0]
|
||||
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
|
||||
|
||||
assert reply in (
|
||||
"i don't know. i just feel like i'm going to throw up. it's not fun.",
|
||||
"i'm not sure. i just feel like i've been feeling like i have to be in a certain place",
|
||||
)
|
||||
|
||||
def test_90_generation_from_short_input(self):
|
||||
model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device)
|
||||
generated_utterances = self.model.generate(**model_inputs)
|
||||
# generated_txt = self.tokenizer.decode(generated_utterances[0])
|
||||
|
||||
# assert generated_txt == "__start__ have you ever heard of sam harris? he's an american singer, songwriter, and actor. __end__"
|
||||
clean_txt = self.tokenizer.decode(generated_utterances[0], **TOK_DECODE_KW)
|
||||
assert clean_txt in (
|
||||
"have you ever been to a sam club? it's a great club in the south.",
|
||||
"have you ever heard of sam harris? he's an american singer, songwriter, and actor.",
|
||||
)
|
||||
@@ -752,6 +752,10 @@ class ModelTesterMixin:
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def set_nan_tensor_to_zero(t):
|
||||
t[t != t] = 0
|
||||
return t
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
with torch.no_grad():
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
@@ -765,7 +769,9 @@ class ModelTesterMixin:
|
||||
return
|
||||
else:
|
||||
self.assertTrue(
|
||||
torch.allclose(tuple_object, dict_object, atol=1e-5),
|
||||
torch.allclose(
|
||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
|
||||
),
|
||||
msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_modeling_bart import TOLERANCE, _assert_tensors_equal, _long_tensor
|
||||
from .test_modeling_bart import TOLERANCE, _long_tensor, assert_tensors_close
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -79,7 +79,17 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=logits.device, dtype=logits.dtype)
|
||||
result_slice = logits[0, 0, :3]
|
||||
_assert_tensors_equal(expected_slice, result_slice, atol=TOLERANCE)
|
||||
assert_tensors_close(expected_slice, result_slice, atol=TOLERANCE)
|
||||
|
||||
@slow
|
||||
def test_enro_generate_one(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
["UN Chief Says There Is No Military Solution in Syria"]
|
||||
).to(torch_device)
|
||||
translated_tokens = self.model.generate(**batch)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
self.assertEqual(self.tgt_text[0], decoded[0])
|
||||
# self.assertEqual(self.tgt_text[1], decoded[1])
|
||||
|
||||
@slow
|
||||
def test_enro_generate(self):
|
||||
|
||||
93
tests/test_tokenization_blenderbot.py
Normal file
93
tests/test_tokenization_blenderbot.py
Normal file
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python3
|
||||
# coding=utf-8
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the;
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Tests for Blenderbot Tokenizers, including common tests for BlenderbotSmallTokenizer."""
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.tokenization_blenderbot import VOCAB_FILES_NAMES, BlenderbotSmallTokenizer, BlenderbotTokenizer
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
class BlenderbotSmallTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = BlenderbotSmallTokenizer
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
vocab = ["__start__", "adapt", "act", "ap@@", "te", "__end__", "__unk__"]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
|
||||
merges = ["#version: 0.2", "a p", "t e</w>", "ap t</w>", "a d", "ad apt</w>", "a c", "ac t</w>", ""]
|
||||
self.special_tokens_map = {"unk_token": "__unk__", "bos_token": "__start__", "eos_token": "__end__"}
|
||||
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return BlenderbotSmallTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "adapt act apte"
|
||||
output_text = "adapt act apte"
|
||||
return input_text, output_text
|
||||
|
||||
def test_full_blenderbot_small_tokenizer(self):
|
||||
tokenizer = BlenderbotSmallTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
||||
text = "adapt act apte"
|
||||
bpe_tokens = ["adapt", "act", "ap@@", "te"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
input_tokens = [tokenizer.bos_token] + tokens + [tokenizer.eos_token]
|
||||
|
||||
input_bpe_tokens = [0, 1, 2, 3, 4, 5]
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
def test_special_tokens_small_tok(self):
|
||||
tok = BlenderbotSmallTokenizer.from_pretrained("facebook/blenderbot-90M")
|
||||
assert tok("sam").input_ids == [1384]
|
||||
src_text = "I am a small frog."
|
||||
encoded = tok([src_text], padding=False, truncation=False)["input_ids"]
|
||||
decoded = tok.batch_decode(encoded, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
assert src_text != decoded # I wish it did!
|
||||
assert decoded == "i am a small frog ."
|
||||
|
||||
|
||||
class Blenderbot3BTokenizerTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def tokenizer_3b(self):
|
||||
return BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B")
|
||||
|
||||
def test_encode_decode_cycle(self):
|
||||
tok = self.tokenizer_3b
|
||||
src_text = " I am a small frog."
|
||||
encoded = tok([src_text], padding=False, truncation=False)["input_ids"]
|
||||
decoded = tok.batch_decode(encoded, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
assert src_text == decoded
|
||||
|
||||
def test_3B_tokenization_same_as_parlai(self):
|
||||
assert self.tokenizer_3b.add_prefix_space
|
||||
assert self.tokenizer_3b([" Sam", "Sam"]).input_ids == [[5502, 2], [5502, 2]]
|
||||
Reference in New Issue
Block a user