From ce91bf9a3431b4d260005de84c0b0fa394409a3c Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 1 Nov 2021 22:38:52 +0530 Subject: [PATCH] [GPTJ] enable common tests and few fixes (#14190) * enable common tests, small fixes * don't tie word embeds * don't ignore lm_head --- src/transformers/models/gptj/configuration_gptj.py | 5 ++++- src/transformers/models/gptj/modeling_gptj.py | 14 +++++++------- tests/test_modeling_gptj.py | 6 ++++-- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py index 40408cb19f..6c754ddc42 100644 --- a/src/transformers/models/gptj/configuration_gptj.py +++ b/src/transformers/models/gptj/configuration_gptj.py @@ -109,6 +109,7 @@ class GPTJConfig(PretrainedConfig): use_cache=True, bos_token_id=50256, eos_token_id=50256, + tie_word_embeddings=False, **kwargs ): self.vocab_size = vocab_size @@ -130,4 +131,6 @@ class GPTJConfig(PretrainedConfig): self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id - super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + super().__init__( + bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs + ) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 0ea10b1bb0..7c01fea81d 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -71,7 +71,7 @@ class GPTJAttention(nn.Module): max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( 1, 1, max_positions, max_positions ), ) @@ -136,7 +136,7 @@ class GPTJAttention(nn.Module): # compute causal mask from causal mask buffer query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() # Keep the attention weights computation in fp32 to avoid overflow issues query = query.to(torch.float32) @@ -674,7 +674,7 @@ class GPTJModel(GPTJPreTrainedModel): GPTJ_START_DOCSTRING, ) class GPTJForCausalLM(GPTJPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"] def __init__(self, config): super().__init__(config) @@ -707,10 +707,10 @@ class GPTJForCausalLM(GPTJPreTrainedModel): torch.cuda.empty_cache() def get_output_embeddings(self): - return None + return self.lm_head def set_output_embeddings(self, new_embeddings): - return + self.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) @@ -847,13 +847,13 @@ class GPTJForCausalLM(GPTJPreTrainedModel): GPTJ_START_DOCSTRING, ) class GPTJForSequenceClassification(GPTJPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.transformer = GPTJModel(config) - self.score = nn.Linear(config.n_positions, self.num_labels, bias=False) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) self.init_weights() diff --git a/tests/test_modeling_gptj.py b/tests/test_modeling_gptj.py index e0ef8a905e..f443dc1af5 100644 --- a/tests/test_modeling_gptj.py +++ b/tests/test_modeling_gptj.py @@ -21,7 +21,8 @@ from transformers import GPTJConfig, is_torch_available from transformers.testing_utils import require_torch, slow, tooslow, torch_device from .test_configuration_common import ConfigTester -from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask +from .test_generation_utils import GenerationTesterMixin +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask if is_torch_available(): @@ -350,7 +351,7 @@ class GPTJModelTester: @require_torch -class GPTJModelTest(unittest.TestCase): +class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification) if is_torch_available() else () all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else () @@ -358,6 +359,7 @@ class GPTJModelTest(unittest.TestCase): test_pruning = False test_missing_keys = False test_model_parallel = False + test_head_masking = False # special case for DoubleHeads model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):