diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 1d11ef8cfa..9fd1e0a23b 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -19,6 +19,7 @@ import os.path
 import random
 import tempfile
 import unittest
+from typing import List
 
 from transformers import is_torch_available
 
@@ -629,10 +630,10 @@ class ModelTesterMixin:
 
         # iterate over all generative models
         for model_class in self.all_generative_model_classes:
-            model = model_class(config)
+            model = model_class(config).to(torch_device)
 
             if config.bos_token_id is None:
-                # if bos token id is not defined mobel needs input_ids
+                # if bos token id is not defined, model needs input_ids
                 with self.assertRaises(AssertionError):
                     model.generate(do_sample=True, max_length=5)
                 # num_return_sequences = 1
@@ -651,7 +652,10 @@ class ModelTesterMixin:
 
             # check bad words tokens language generation
             # create list of 1-seq bad token and list of 2-seq of bad tokens
-            bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
+            bad_words_ids = [
+                self._generate_random_bad_tokens(1, model.config),
+                self._generate_random_bad_tokens(2, model.config),
+            ]
             output_tokens = model.generate(
                 input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
             )
@@ -661,10 +665,12 @@ class ModelTesterMixin:
 
     def test_lm_head_model_random_beam_search_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
+        input_ids = (inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]).to(
+            torch_device
+        )
 
         for model_class in self.all_generative_model_classes:
-            model = model_class(config)
+            model = model_class(config).to(torch_device)
 
             if config.bos_token_id is None:
                 # if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
@@ -684,7 +690,10 @@ class ModelTesterMixin:
 
             # check bad words tokens language generation
             # create list of 1-seq bad token and list of 2-seq of bad tokens
-            bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
+            bad_words_ids = [
+                self._generate_random_bad_tokens(1, model.config),
+                self._generate_random_bad_tokens(2, model.config),
+            ]
             output_tokens = model.generate(
                 input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
             )
@@ -692,20 +701,13 @@ class ModelTesterMixin:
             generated_ids = output_tokens[:, input_ids.shape[-1] :]
             self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
 
-    def _generate_random_bad_tokens(self, num_bad_tokens, model):
+    def _generate_random_bad_tokens(self, num_bad_tokens: int, config) -> List[int]:
         # special tokens cannot be bad tokens
-        special_tokens = []
-        if model.config.bos_token_id is not None:
-            special_tokens.append(model.config.bos_token_id)
-        if model.config.pad_token_id is not None:
-            special_tokens.append(model.config.pad_token_id)
-        if model.config.eos_token_id is not None:
-            special_tokens.append(model.config.eos_token_id)
-
+        special_tokens = [x for x in [config.bos_token_id, config.eos_token_id, config.pad_token_id] if x is not None]
         # create random bad tokens that are not special tokens
         bad_tokens = []
         while len(bad_tokens) < num_bad_tokens:
-            token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).numpy()[0]
+            token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).cpu().numpy()[0]
             if token not in special_tokens:
                 bad_tokens.append(token)
         return bad_tokens