[Seq2Seq Templates] Correct some TF-serving errors and add gradient checkpointing to PT by default. (#9334)
* correct tests * correct shape and get_tf_activation * more correction tf * add gradient checkpointing to templates * correct typo
This commit is contained in:
committed by
GitHub
parent
8e74eca7f2
commit
83fdd252f6
@@ -74,6 +74,8 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
|||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
relevant if ``config.is_decoder=True``.
|
relevant if ``config.is_decoder=True``.
|
||||||
|
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||||
{% else -%}
|
{% else -%}
|
||||||
vocab_size (:obj:`int`, `optional`, defaults to 50265):
|
vocab_size (:obj:`int`, `optional`, defaults to 50265):
|
||||||
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
|
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
|
||||||
@@ -172,6 +174,7 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
|||||||
init_std=0.02,
|
init_std=0.02,
|
||||||
decoder_start_token_id=2,
|
decoder_start_token_id=2,
|
||||||
classifier_dropout=0.0,
|
classifier_dropout=0.0,
|
||||||
|
gradient_checkpointing=False,
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
@@ -222,6 +225,7 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
|||||||
self.classifier_dropout = classifier_dropout
|
self.classifier_dropout = classifier_dropout
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.num_hidden_layers = encoder_layers
|
self.num_hidden_layers = encoder_layers
|
||||||
|
self.gradient_checkpointing = gradient_checkpointing
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
|
|
||||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers.modeling_tf_outputs import TFCausalLMOutput
|
from transformers.modeling_tf_outputs import TFCausalLMOutput
|
||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
||||||
@@ -37,14 +38,14 @@ from ...modeling_tf_outputs import (
|
|||||||
TFTokenClassifierOutput,
|
TFTokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
|
TFCausalLanguageModelingLoss,
|
||||||
TFMaskedLanguageModelingLoss,
|
TFMaskedLanguageModelingLoss,
|
||||||
TFMultipleChoiceLoss,
|
TFMultipleChoiceLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
TFSequenceClassificationLoss,
|
TFSequenceClassificationLoss,
|
||||||
TFTokenClassificationLoss,
|
|
||||||
TFCausalLanguageModelingLoss,
|
|
||||||
TFSequenceSummary,
|
TFSequenceSummary,
|
||||||
|
TFTokenClassificationLoss,
|
||||||
get_initializer,
|
get_initializer,
|
||||||
input_processing,
|
input_processing,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
@@ -503,7 +504,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def set_input_embeddings(self, value):
|
def set_input_embeddings(self, value):
|
||||||
self.embeddings.word_embeddings = value
|
self.embeddings.word_embeddings = value
|
||||||
self.embeddings.vocab_size = value.shape[0]
|
self.embeddings.vocab_size = shape_list(value)[0]
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
"""Prunes heads of the model.
|
"""Prunes heads of the model.
|
||||||
@@ -1109,7 +1110,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
|||||||
Returns:
|
Returns:
|
||||||
tf.Tensor with dummy inputs
|
tf.Tensor with dummy inputs
|
||||||
"""
|
"""
|
||||||
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
|
return {"input_ids": tf.convert_to_tensor(MULTIPLE_CHOICE_DUMMY_INPUTS)}
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
@@ -1404,7 +1405,7 @@ from typing import Dict, Optional, Tuple, Union
|
|||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from ...activations_tf import ACT2FN
|
from ...activations_tf import get_tf_activation
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -1640,7 +1641,7 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
self.activation_fn = ACT2FN[config.activation_function]
|
self.activation_fn = get_tf_activation(config.activation_function)
|
||||||
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
|
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
|
||||||
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
|
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
|
||||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||||
@@ -1689,7 +1690,7 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
|||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
)
|
)
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
self.activation_fn = ACT2FN[config.activation_function]
|
self.activation_fn = get_tf_activation(config.activation_function)
|
||||||
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
|
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
|
||||||
|
|
||||||
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||||
@@ -1782,8 +1783,8 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
|||||||
@property
|
@property
|
||||||
def dummy_inputs(self):
|
def dummy_inputs(self):
|
||||||
pad_token = 1
|
pad_token = 1
|
||||||
input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
|
input_ids = tf.cast(tf.convert_to_tensor(DUMMY_INPUTS), tf.int32)
|
||||||
decoder_input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
|
decoder_input_ids = tf.cast(tf.convert_to_tensor(DUMMY_INPUTS), tf.int32)
|
||||||
dummy_inputs = {
|
dummy_inputs = {
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": tf.math.not_equal(input_ids, pad_token),
|
"attention_mask": tf.math.not_equal(input_ids, pad_token),
|
||||||
@@ -2134,7 +2135,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
past_key_values_length = (
|
past_key_values_length = (
|
||||||
inputs["past_key_values"][0][0].shape[2] if inputs["past_key_values"] is not None else 0
|
shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
@@ -2390,7 +2391,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||||||
# {{cookiecutter.uppercase_modelname}} is a special case where the bias has two dimensions
|
# {{cookiecutter.uppercase_modelname}} is a special case where the bias has two dimensions
|
||||||
# and not named just `bias`
|
# and not named just `bias`
|
||||||
if new_num_tokens is not None:
|
if new_num_tokens is not None:
|
||||||
num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens)
|
num_tokens_to_copy = min(shape_list(self.final_logits_bias)[0], new_num_tokens)
|
||||||
init_bias = tf.zeros((new_num_tokens,))
|
init_bias = tf.zeros((new_num_tokens,))
|
||||||
init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
|
init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
|
||||||
self.final_logits_bias = self.add_weight(
|
self.final_logits_bias = self.add_weight(
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import torch.utils.checkpoint
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -48,7 +49,6 @@ from ...modeling_utils import (
|
|||||||
prune_linear_layer,
|
prune_linear_layer,
|
||||||
)
|
)
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ...activations import ACT2FN
|
|
||||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
||||||
|
|
||||||
|
|
||||||
@@ -1809,7 +1809,13 @@ class {{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
|
|||||||
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
|
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
|
||||||
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
||||||
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
||||||
return hidden_states, attn_weights
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (attn_weights,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
||||||
@@ -1846,7 +1852,8 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
|||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[torch.Tensor] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1907,12 +1914,15 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
|||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
hidden_states = self.final_layer_norm(hidden_states)
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
|
|
||||||
return (
|
outputs = (hidden_states,)
|
||||||
hidden_states,
|
|
||||||
self_attn_weights,
|
if output_attentions:
|
||||||
present_key_value,
|
outputs += (self_attn_weights, cross_attn_weights)
|
||||||
cross_attn_weights,
|
|
||||||
)
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->{{cookiecutter.camelcase_modelname}}
|
# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->{{cookiecutter.camelcase_modelname}}
|
||||||
@@ -2178,12 +2188,28 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
|||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
dropout_probability = random.uniform(0, 1)
|
dropout_probability = random.uniform(0, 1)
|
||||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
attn = None
|
layer_outputs = (None, None)
|
||||||
else:
|
else:
|
||||||
hidden_states, attn = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
|
if getattr(self.config, "gradient_checkpointing", False):
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs, output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(encoder_layer),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (attn,)
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
@@ -2355,21 +2381,46 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
|||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
hidden_states, layer_self_attn, present_key_value, layer_cross_attn = decoder_layer(
|
if getattr(self.config, "gradient_checkpointing", False):
|
||||||
|
if use_cache:
|
||||||
|
raise ValueError(
|
||||||
|
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, output_attentions, use_cache)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(decoder_layer),
|
||||||
|
hidden_states,
|
||||||
|
combined_attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_decoder_cache += (present_key_value,)
|
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_outputs[1],)
|
||||||
all_cross_attentions += (layer_cross_attn,)
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
# add hidden states from the last decoder layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
|
|||||||
@@ -532,7 +532,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelIntegrationTest(unittest.TestCa
|
|||||||
expected_slice = tf.Tensor(
|
expected_slice = tf.Tensor(
|
||||||
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]],
|
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]],
|
||||||
)
|
)
|
||||||
self.assertTrue(tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE)
|
||||||
|
|
||||||
def test_inference_with_head(self):
|
def test_inference_with_head(self):
|
||||||
model = TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
|
model = TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
|
||||||
@@ -547,7 +547,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelIntegrationTest(unittest.TestCa
|
|||||||
expected_slice = tf.Tensor(
|
expected_slice = tf.Tensor(
|
||||||
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]],
|
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]],
|
||||||
)
|
)
|
||||||
self.assertTrue(tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE)
|
||||||
|
|
||||||
def test_seq_to_seq_generation(self):
|
def test_seq_to_seq_generation(self):
|
||||||
hf = TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
|
hf = TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
|
||||||
|
|||||||
@@ -683,23 +683,6 @@ class {{cookiecutter.camelcase_modelname}}ModelTest(ModelTesterMixin, Generation
|
|||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
def test_initialization_more(self):
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
|
||||||
model = {{cookiecutter.camelcase_modelname}}Model(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
# test init
|
|
||||||
self.assertTrue((model.encoder.embed_tokens.weight == model.shared.weight).all().item())
|
|
||||||
|
|
||||||
def _check_var(module):
|
|
||||||
"""Check that we initialized various parameters from N(0, config.init_std)."""
|
|
||||||
self.assertAlmostEqual(torch.std(module.weight).item(), config.init_std, 2)
|
|
||||||
|
|
||||||
_check_var(model.encoder.embed_tokens)
|
|
||||||
_check_var(model.encoder.layers[0].self_attn.k_proj)
|
|
||||||
_check_var(model.encoder.layers[0].fc1)
|
|
||||||
_check_var(model.encoder.embed_positions)
|
|
||||||
|
|
||||||
def test_save_load_strict(self):
|
def test_save_load_strict(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
|
|||||||
Reference in New Issue
Block a user