diff --git a/tests/test_modeling_flaubert.py b/tests/test_modeling_flaubert.py index 1d3daa2cab..8f4b882821 100644 --- a/tests/test_modeling_flaubert.py +++ b/tests/test_modeling_flaubert.py @@ -12,11 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os +import tempfile import unittest from transformers import FlaubertConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask @@ -401,6 +402,29 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): model = FlaubertModel.from_pretrained(model_name) self.assertIsNotNone(model) + @slow + @require_torch_gpu + def test_torchscript_device_change(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + + # FlauBertForMultipleChoice behaves incorrectly in JIT environments. + if model_class == FlaubertForMultipleChoice: + return + + config.torchscript = True + model = model_class(config=config) + + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + traced_model = torch.jit.trace( + model, (inputs_dict["input_ids"].to("cpu"), inputs_dict["attention_mask"].to("cpu")) + ) + + with tempfile.TemporaryDirectory() as tmp: + torch.jit.save(traced_model, os.path.join(tmp, "traced_model.pt")) + loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device) + loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) + @require_torch class FlaubertModelIntegrationTest(unittest.TestCase):