[fsmt] rewrite SinusoidalPositionalEmbedding + USE_CUDA test fixes + new TranslationPipeline test (#7224)
* fix USE_CUDA, add pipeline * USE_CUDA fix * recode SinusoidalPositionalEmbedding into nn.Embedding subclass was needed for torchscript to work - this is now part of the state_dict, so will have to remove these keys during save_pretrained * back out (ci debug) * restore * slow last? * facilitate not saving certain keys and test * remove no longer used keys * style * fix logging import * cleanup * Update src/transformers/modeling_utils.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * fix bug in max_positional_embeddings * rename keys to keys_to_never_save per suggestion, improve the setup * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -20,7 +21,7 @@ import timeout_decorator # noqa
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.file_utils import WEIGHTS_NAME, cached_property
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@@ -37,6 +38,7 @@ if is_torch_available():
|
||||
invert_mask,
|
||||
shift_tokens_right,
|
||||
)
|
||||
from transformers.pipelines import TranslationPipeline
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -207,6 +209,27 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||
self.assertEqual(info["missing_keys"], [])
|
||||
|
||||
def test_save_load_no_save_keys(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
state_dict_no_save_keys = getattr(model, "state_dict_no_save_keys", None)
|
||||
if state_dict_no_save_keys is None:
|
||||
continue
|
||||
|
||||
# check the keys are in the original state_dict
|
||||
for k in state_dict_no_save_keys:
|
||||
self.assertIn(k, model.state_dict())
|
||||
|
||||
# check that certain keys didn't get saved with the model
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
|
||||
state_dict_saved = torch.load(output_model_file)
|
||||
for k in state_dict_no_save_keys:
|
||||
self.assertNotIn(k, state_dict_saved)
|
||||
|
||||
@unittest.skip("can't be implemented for FSMT due to dual vocab.")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
@@ -219,14 +242,6 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("failing on CI - needs review")
|
||||
def test_torchscript_output_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("failing on CI - needs review")
|
||||
def test_torchscript_output_hidden_state(self):
|
||||
pass
|
||||
|
||||
# def test_auto_model(self):
|
||||
# # XXX: add a tiny model to s3?
|
||||
# model_name = "facebook/wmt19-ru-en-tiny"
|
||||
@@ -366,6 +381,14 @@ def _long_tensor(tok_lst):
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
pairs = [
|
||||
["en-ru"],
|
||||
["ru-en"],
|
||||
["en-de"],
|
||||
["de-en"],
|
||||
]
|
||||
|
||||
|
||||
@require_torch
|
||||
class FSMTModelIntegrationTests(unittest.TestCase):
|
||||
tokenizers_cache = {}
|
||||
@@ -399,7 +422,7 @@ class FSMTModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
src_text = "My friend computer will translate this for me"
|
||||
input_ids = tokenizer([src_text], return_tensors="pt")["input_ids"]
|
||||
input_ids = _long_tensor(input_ids)
|
||||
input_ids = _long_tensor(input_ids).to(torch_device)
|
||||
inputs_dict = prepare_fsmt_inputs_dict(model.config, input_ids)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
@@ -409,19 +432,10 @@ class FSMTModelIntegrationTests(unittest.TestCase):
|
||||
# may have to adjust if switched to a different checkpoint
|
||||
expected_slice = torch.tensor(
|
||||
[[-1.5753, -1.5753, 2.8975], [-0.9540, -0.9540, 1.0299], [-3.3131, -3.3131, 0.5219]]
|
||||
)
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
["en-ru"],
|
||||
["ru-en"],
|
||||
["en-de"],
|
||||
["de-en"],
|
||||
]
|
||||
)
|
||||
@slow
|
||||
def test_translation(self, pair):
|
||||
def translation_setup(self, pair):
|
||||
text = {
|
||||
"en": "Machine learning is great, isn't it?",
|
||||
"ru": "Машинное обучение - это здорово, не так ли?",
|
||||
@@ -432,16 +446,32 @@ class FSMTModelIntegrationTests(unittest.TestCase):
|
||||
print(f"Testing {src} -> {tgt}")
|
||||
mname = f"facebook/wmt19-{pair}"
|
||||
|
||||
src_sentence = text[src]
|
||||
tgt_sentence = text[tgt]
|
||||
src_text = text[src]
|
||||
tgt_text = text[tgt]
|
||||
|
||||
tokenizer = self.get_tokenizer(mname)
|
||||
model = self.get_model(mname)
|
||||
return tokenizer, model, src_text, tgt_text
|
||||
|
||||
@parameterized.expand(pairs)
|
||||
@slow
|
||||
def test_translation_direct(self, pair):
|
||||
tokenizer, model, src_text, tgt_text = self.translation_setup(pair)
|
||||
|
||||
input_ids = tokenizer.encode(src_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
input_ids = tokenizer.encode(src_sentence, return_tensors="pt")
|
||||
outputs = model.generate(input_ids)
|
||||
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
assert decoded == tgt_sentence, f"\n\ngot: {decoded}\nexp: {tgt_sentence}\n"
|
||||
assert decoded == tgt_text, f"\n\ngot: {decoded}\nexp: {tgt_text}\n"
|
||||
|
||||
@parameterized.expand(pairs)
|
||||
@slow
|
||||
def test_translation_pipeline(self, pair):
|
||||
tokenizer, model, src_text, tgt_text = self.translation_setup(pair)
|
||||
device = 0 if torch_device == "cuda" else -1
|
||||
pipeline = TranslationPipeline(model, tokenizer, framework="pt", device=device)
|
||||
output = pipeline([src_text])
|
||||
self.assertEqual([tgt_text], [x["translation_text"] for x in output])
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -449,10 +479,9 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
padding_idx = 1
|
||||
tolerance = 1e-4
|
||||
|
||||
@unittest.skip("failing on CI - needs review")
|
||||
def test_basic(self):
|
||||
input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device)
|
||||
emb1 = SinusoidalPositionalEmbedding(embedding_dim=6, padding_idx=self.padding_idx, init_size=6).to(
|
||||
emb1 = SinusoidalPositionalEmbedding(num_positions=6, embedding_dim=6, padding_idx=self.padding_idx).to(
|
||||
torch_device
|
||||
)
|
||||
emb = emb1(input_ids)
|
||||
@@ -461,7 +490,7 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
[9.0930e-01, 1.9999e-02, 2.0000e-04, -4.1615e-01, 9.9980e-01, 1.0000e00],
|
||||
[1.4112e-01, 2.9995e-02, 3.0000e-04, -9.8999e-01, 9.9955e-01, 1.0000e00],
|
||||
]
|
||||
)
|
||||
).to(torch_device)
|
||||
self.assertTrue(
|
||||
torch.allclose(emb[0], desired_weights, atol=self.tolerance),
|
||||
msg=f"\nexp:\n{desired_weights}\ngot:\n{emb[0]}\n",
|
||||
@@ -469,14 +498,10 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
|
||||
def test_odd_embed_dim(self):
|
||||
# odd embedding_dim is allowed
|
||||
SinusoidalPositionalEmbedding.get_embedding(
|
||||
num_embeddings=4, embedding_dim=5, padding_idx=self.padding_idx
|
||||
).to(torch_device)
|
||||
SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=self.padding_idx).to(torch_device)
|
||||
|
||||
# odd num_embeddings is allowed
|
||||
SinusoidalPositionalEmbedding.get_embedding(
|
||||
num_embeddings=5, embedding_dim=4, padding_idx=self.padding_idx
|
||||
).to(torch_device)
|
||||
SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=self.padding_idx).to(torch_device)
|
||||
|
||||
@unittest.skip("different from marian (needs more research)")
|
||||
def test_positional_emb_weights_against_marian(self):
|
||||
@@ -488,7 +513,7 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
[0.90929741, 0.93651021, 0.95829457, 0.97505713, 0.98720258],
|
||||
]
|
||||
)
|
||||
emb1 = SinusoidalPositionalEmbedding(init_size=512, embedding_dim=512, padding_idx=self.padding_idx).to(
|
||||
emb1 = SinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=self.padding_idx).to(
|
||||
torch_device
|
||||
)
|
||||
weights = emb1.weights.data[:3, :5]
|
||||
|
||||
Reference in New Issue
Block a user