[fsmt] rewrite SinusoidalPositionalEmbedding + USE_CUDA test fixes + new TranslationPipeline test (#7224)
* fix USE_CUDA, add pipeline * USE_CUDA fix * recode SinusoidalPositionalEmbedding into nn.Embedding subclass was needed for torchscript to work - this is now part of the state_dict, so will have to remove these keys during save_pretrained * back out (ci debug) * restore * slow last? * facilitate not saving certain keys and test * remove no longer used keys * style * fix logging import * cleanup * Update src/transformers/modeling_utils.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * fix bug in max_positional_embeddings * rename keys to keys_to_never_save per suggestion, improve the setup * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -31,16 +31,14 @@ import torch
|
||||
from fairseq import hub_utils
|
||||
from fairseq.data.dictionary import Dictionary
|
||||
|
||||
from transformers import WEIGHTS_NAME
|
||||
from transformers import WEIGHTS_NAME, logging
|
||||
from transformers.configuration_fsmt import FSMTConfig
|
||||
from transformers.modeling_fsmt import FSMTForConditionalGeneration
|
||||
from transformers.tokenization_fsmt import VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logging.set_verbosity_warning()
|
||||
|
||||
json_indent = 2
|
||||
|
||||
@@ -229,6 +227,8 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
|
||||
"model.decoder.version",
|
||||
"model.encoder_embed_tokens.weight",
|
||||
"model.decoder_embed_tokens.weight",
|
||||
"model.encoder.embed_positions._float_tensor",
|
||||
"model.decoder.embed_positions._float_tensor",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
model_state_dict.pop(k, None)
|
||||
|
||||
@@ -397,24 +397,18 @@ class FSMTEncoder(nn.Module):
|
||||
|
||||
def __init__(self, config: FSMTConfig, embed_tokens):
|
||||
super().__init__()
|
||||
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.encoder_layerdrop
|
||||
|
||||
self.padding_idx = embed_tokens.padding_idx
|
||||
self.embed_tokens = embed_tokens
|
||||
embed_dim = embed_tokens.embedding_dim
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
self.padding_idx = embed_tokens.padding_idx
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
|
||||
self.embed_tokens = embed_tokens
|
||||
# print(config.max_position_embeddings, embed_dim, self.padding_idx)
|
||||
num_embeddings = config.src_vocab_size
|
||||
self.embed_positions = SinusoidalPositionalEmbedding(
|
||||
embed_dim,
|
||||
self.padding_idx,
|
||||
init_size=num_embeddings + self.padding_idx + 1, # removed: config.max_position_embeddings
|
||||
config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx
|
||||
)
|
||||
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layers = nn.ModuleList(
|
||||
[EncoderLayer(config) for _ in range(config.encoder_layers)]
|
||||
) # type: List[EncoderLayer]
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
|
||||
@@ -570,15 +564,11 @@ class FSMTDecoder(nn.Module):
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = embed_tokens.padding_idx
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.embed_tokens = embed_tokens
|
||||
embed_dim = embed_tokens.embedding_dim
|
||||
num_embeddings = config.tgt_vocab_size
|
||||
self.embed_positions = SinusoidalPositionalEmbedding(
|
||||
embed_dim,
|
||||
self.padding_idx,
|
||||
init_size=num_embeddings + self.padding_idx + 1,
|
||||
config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
||||
@@ -1003,6 +993,14 @@ class FSMTModel(PretrainedFSMTModel):
|
||||
)
|
||||
class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
||||
base_model_prefix = "model"
|
||||
authorized_missing_keys = [
|
||||
"model.encoder.embed_positions.weight",
|
||||
"model.decoder.embed_positions.weight",
|
||||
]
|
||||
keys_to_never_save = [
|
||||
"model.encoder.embed_positions.weight",
|
||||
"model.decoder.embed_positions.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: FSMTConfig):
|
||||
super().__init__(config)
|
||||
@@ -1137,36 +1135,34 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
||||
return self.model.decoder.embed_tokens
|
||||
|
||||
|
||||
def make_positions(tensor, padding_idx: int):
|
||||
"""Replace non-padding symbols with their position numbers.
|
||||
|
||||
Position numbers begin at padding_idx+1. Padding symbols are ignored.
|
||||
class SinusoidalPositionalEmbedding(nn.Embedding):
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully
|
||||
# balanced to both work with ONNX export and XLA. In particular XLA
|
||||
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
|
||||
# how to handle the dtype kwarg in cumsum.
|
||||
mask = tensor.ne(padding_idx).int()
|
||||
return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
|
||||
This module produces sinusoidal positional embeddings of any length.
|
||||
|
||||
|
||||
class SinusoidalPositionalEmbedding(nn.Module):
|
||||
"""This module produces sinusoidal positional embeddings of any length.
|
||||
We don't want to save the weight of this embedding since it's not trained
|
||||
(deterministic) and it can be huge.
|
||||
|
||||
Padding symbols are ignored.
|
||||
|
||||
These embeddings get automatically extended in forward if more positions is needed.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, padding_idx, init_size=1024):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.padding_idx = padding_idx
|
||||
self.weights = SinusoidalPositionalEmbedding.get_embedding(init_size, embedding_dim, padding_idx)
|
||||
self.register_buffer("_float_tensor", torch.zeros(1)) # used for getting the right device
|
||||
self.max_positions = int(1e5)
|
||||
def __init__(self, num_positions, embedding_dim, padding_idx):
|
||||
self.make_weight(num_positions, embedding_dim, padding_idx)
|
||||
|
||||
def make_weight(self, num_positions, embedding_dim, padding_idx, device=None):
|
||||
weight = self.get_embedding(num_positions, embedding_dim, padding_idx)
|
||||
if device is not None:
|
||||
weight = weight.to(device)
|
||||
if not hasattr(self, "weight"):
|
||||
super().__init__(num_positions, embedding_dim, padding_idx, _weight=weight)
|
||||
else:
|
||||
self.weight = nn.Parameter(weight)
|
||||
self.weight.detach_()
|
||||
self.weight.requires_grad = False
|
||||
|
||||
# XXX: bart uses s/num_embeddings/num_positions/, s/weights/weight/ - could make those match
|
||||
@staticmethod
|
||||
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
def get_embedding(num_embeddings, embedding_dim, padding_idx):
|
||||
"""Build sinusoidal embeddings.
|
||||
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
@@ -1184,28 +1180,30 @@ class SinusoidalPositionalEmbedding(nn.Module):
|
||||
emb[padding_idx, :] = 0
|
||||
return emb
|
||||
|
||||
@staticmethod
|
||||
def make_positions(tensor, padding_idx: int):
|
||||
"""Replace non-padding symbols with their position numbers.
|
||||
|
||||
Position numbers begin at padding_idx+1. Padding symbols are ignored.
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully
|
||||
# balanced to both work with ONNX export and XLA. In particular XLA
|
||||
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
|
||||
# how to handle the dtype kwarg in cumsum.
|
||||
mask = tensor.ne(padding_idx).int()
|
||||
return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input,
|
||||
incremental_state: Optional[Any] = None,
|
||||
timestep: Optional[Tensor] = None,
|
||||
positions: Optional[Any] = None,
|
||||
):
|
||||
"""Input is expected to be of size [bsz x seqlen]."""
|
||||
# bspair = torch.onnx.operators.shape_as_tensor(input)
|
||||
# bsz, seq_len = bspair[0], bspair[1]
|
||||
bsz, seq_len = input.shape[:2]
|
||||
max_pos = self.padding_idx + 1 + seq_len
|
||||
if self.weights is None or max_pos > self.weights.size(0):
|
||||
# recompute/expand embeddings if needed
|
||||
self.weights = SinusoidalPositionalEmbedding.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
|
||||
self.weights = self.weights.to(self._float_tensor)
|
||||
|
||||
if incremental_state is not None:
|
||||
# positions is the same for every token when decoding a single step
|
||||
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
|
||||
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
|
||||
|
||||
positions = make_positions(input, self.padding_idx)
|
||||
|
||||
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
|
||||
if max_pos > self.weight.size(0):
|
||||
# expand embeddings if needed
|
||||
self.make_weight(max_pos, self.embedding_dim, self.padding_idx, device=input.device)
|
||||
positions = self.make_positions(input, self.padding_idx)
|
||||
return super().forward(positions)
|
||||
|
||||
@@ -391,10 +391,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
derived classes of the same architecture adding modules on top of the base model.
|
||||
- **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
|
||||
when loading the model (and avoid unnecessary warnings).
|
||||
- **keys_to_never_save** (:obj:`Optional[List[str]]`) -- A list of of tensor names to ignore
|
||||
when saving the model (useful for keys that aren't trained, but which are deterministic)
|
||||
|
||||
"""
|
||||
config_class = None
|
||||
base_model_prefix = ""
|
||||
authorized_missing_keys = None
|
||||
keys_to_never_save = None
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
||||
@@ -688,6 +692,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
# Attach architecture to the config
|
||||
model_to_save.config.architectures = [model_to_save.__class__.__name__]
|
||||
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
# Handle the case where some state_dict keys shouldn't be saved
|
||||
if self.keys_to_never_save is not None:
|
||||
state_dict = {k: v for k, v in state_dict.items() if k not in self.keys_to_never_save}
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||
|
||||
@@ -698,10 +708,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
# Save configuration file
|
||||
model_to_save.config.save_pretrained(save_directory)
|
||||
# xm.save takes care of saving only from master
|
||||
xm.save(model_to_save.state_dict(), output_model_file)
|
||||
xm.save(state_dict, output_model_file)
|
||||
else:
|
||||
model_to_save.config.save_pretrained(save_directory)
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
torch.save(state_dict, output_model_file)
|
||||
|
||||
logger.info("Model weights saved in {}".format(output_model_file))
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -20,7 +21,7 @@ import timeout_decorator # noqa
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.file_utils import WEIGHTS_NAME, cached_property
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@@ -37,6 +38,7 @@ if is_torch_available():
|
||||
invert_mask,
|
||||
shift_tokens_right,
|
||||
)
|
||||
from transformers.pipelines import TranslationPipeline
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -207,6 +209,27 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||
self.assertEqual(info["missing_keys"], [])
|
||||
|
||||
def test_save_load_no_save_keys(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
state_dict_no_save_keys = getattr(model, "state_dict_no_save_keys", None)
|
||||
if state_dict_no_save_keys is None:
|
||||
continue
|
||||
|
||||
# check the keys are in the original state_dict
|
||||
for k in state_dict_no_save_keys:
|
||||
self.assertIn(k, model.state_dict())
|
||||
|
||||
# check that certain keys didn't get saved with the model
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
|
||||
state_dict_saved = torch.load(output_model_file)
|
||||
for k in state_dict_no_save_keys:
|
||||
self.assertNotIn(k, state_dict_saved)
|
||||
|
||||
@unittest.skip("can't be implemented for FSMT due to dual vocab.")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
@@ -219,14 +242,6 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("failing on CI - needs review")
|
||||
def test_torchscript_output_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("failing on CI - needs review")
|
||||
def test_torchscript_output_hidden_state(self):
|
||||
pass
|
||||
|
||||
# def test_auto_model(self):
|
||||
# # XXX: add a tiny model to s3?
|
||||
# model_name = "facebook/wmt19-ru-en-tiny"
|
||||
@@ -366,6 +381,14 @@ def _long_tensor(tok_lst):
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
pairs = [
|
||||
["en-ru"],
|
||||
["ru-en"],
|
||||
["en-de"],
|
||||
["de-en"],
|
||||
]
|
||||
|
||||
|
||||
@require_torch
|
||||
class FSMTModelIntegrationTests(unittest.TestCase):
|
||||
tokenizers_cache = {}
|
||||
@@ -399,7 +422,7 @@ class FSMTModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
src_text = "My friend computer will translate this for me"
|
||||
input_ids = tokenizer([src_text], return_tensors="pt")["input_ids"]
|
||||
input_ids = _long_tensor(input_ids)
|
||||
input_ids = _long_tensor(input_ids).to(torch_device)
|
||||
inputs_dict = prepare_fsmt_inputs_dict(model.config, input_ids)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
@@ -409,19 +432,10 @@ class FSMTModelIntegrationTests(unittest.TestCase):
|
||||
# may have to adjust if switched to a different checkpoint
|
||||
expected_slice = torch.tensor(
|
||||
[[-1.5753, -1.5753, 2.8975], [-0.9540, -0.9540, 1.0299], [-3.3131, -3.3131, 0.5219]]
|
||||
)
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
["en-ru"],
|
||||
["ru-en"],
|
||||
["en-de"],
|
||||
["de-en"],
|
||||
]
|
||||
)
|
||||
@slow
|
||||
def test_translation(self, pair):
|
||||
def translation_setup(self, pair):
|
||||
text = {
|
||||
"en": "Machine learning is great, isn't it?",
|
||||
"ru": "Машинное обучение - это здорово, не так ли?",
|
||||
@@ -432,16 +446,32 @@ class FSMTModelIntegrationTests(unittest.TestCase):
|
||||
print(f"Testing {src} -> {tgt}")
|
||||
mname = f"facebook/wmt19-{pair}"
|
||||
|
||||
src_sentence = text[src]
|
||||
tgt_sentence = text[tgt]
|
||||
src_text = text[src]
|
||||
tgt_text = text[tgt]
|
||||
|
||||
tokenizer = self.get_tokenizer(mname)
|
||||
model = self.get_model(mname)
|
||||
return tokenizer, model, src_text, tgt_text
|
||||
|
||||
@parameterized.expand(pairs)
|
||||
@slow
|
||||
def test_translation_direct(self, pair):
|
||||
tokenizer, model, src_text, tgt_text = self.translation_setup(pair)
|
||||
|
||||
input_ids = tokenizer.encode(src_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
input_ids = tokenizer.encode(src_sentence, return_tensors="pt")
|
||||
outputs = model.generate(input_ids)
|
||||
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
assert decoded == tgt_sentence, f"\n\ngot: {decoded}\nexp: {tgt_sentence}\n"
|
||||
assert decoded == tgt_text, f"\n\ngot: {decoded}\nexp: {tgt_text}\n"
|
||||
|
||||
@parameterized.expand(pairs)
|
||||
@slow
|
||||
def test_translation_pipeline(self, pair):
|
||||
tokenizer, model, src_text, tgt_text = self.translation_setup(pair)
|
||||
device = 0 if torch_device == "cuda" else -1
|
||||
pipeline = TranslationPipeline(model, tokenizer, framework="pt", device=device)
|
||||
output = pipeline([src_text])
|
||||
self.assertEqual([tgt_text], [x["translation_text"] for x in output])
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -449,10 +479,9 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
padding_idx = 1
|
||||
tolerance = 1e-4
|
||||
|
||||
@unittest.skip("failing on CI - needs review")
|
||||
def test_basic(self):
|
||||
input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device)
|
||||
emb1 = SinusoidalPositionalEmbedding(embedding_dim=6, padding_idx=self.padding_idx, init_size=6).to(
|
||||
emb1 = SinusoidalPositionalEmbedding(num_positions=6, embedding_dim=6, padding_idx=self.padding_idx).to(
|
||||
torch_device
|
||||
)
|
||||
emb = emb1(input_ids)
|
||||
@@ -461,7 +490,7 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
[9.0930e-01, 1.9999e-02, 2.0000e-04, -4.1615e-01, 9.9980e-01, 1.0000e00],
|
||||
[1.4112e-01, 2.9995e-02, 3.0000e-04, -9.8999e-01, 9.9955e-01, 1.0000e00],
|
||||
]
|
||||
)
|
||||
).to(torch_device)
|
||||
self.assertTrue(
|
||||
torch.allclose(emb[0], desired_weights, atol=self.tolerance),
|
||||
msg=f"\nexp:\n{desired_weights}\ngot:\n{emb[0]}\n",
|
||||
@@ -469,14 +498,10 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
|
||||
def test_odd_embed_dim(self):
|
||||
# odd embedding_dim is allowed
|
||||
SinusoidalPositionalEmbedding.get_embedding(
|
||||
num_embeddings=4, embedding_dim=5, padding_idx=self.padding_idx
|
||||
).to(torch_device)
|
||||
SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=self.padding_idx).to(torch_device)
|
||||
|
||||
# odd num_embeddings is allowed
|
||||
SinusoidalPositionalEmbedding.get_embedding(
|
||||
num_embeddings=5, embedding_dim=4, padding_idx=self.padding_idx
|
||||
).to(torch_device)
|
||||
SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=self.padding_idx).to(torch_device)
|
||||
|
||||
@unittest.skip("different from marian (needs more research)")
|
||||
def test_positional_emb_weights_against_marian(self):
|
||||
@@ -488,7 +513,7 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
[0.90929741, 0.93651021, 0.95829457, 0.97505713, 0.98720258],
|
||||
]
|
||||
)
|
||||
emb1 = SinusoidalPositionalEmbedding(init_size=512, embedding_dim=512, padding_idx=self.padding_idx).to(
|
||||
emb1 = SinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=self.padding_idx).to(
|
||||
torch_device
|
||||
)
|
||||
weights = emb1.weights.data[:3, :5]
|
||||
|
||||
Reference in New Issue
Block a user