ALBERT passes all tests
This commit is contained in:
@@ -7,7 +7,7 @@ class AlbertConfig(PretrainedConfig):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file,
|
||||
vocab_size_or_config_json_file=30000,
|
||||
embedding_size=128,
|
||||
hidden_size=4096,
|
||||
num_hidden_layers=12,
|
||||
@@ -15,7 +15,6 @@ class AlbertConfig(PretrainedConfig):
|
||||
num_attention_heads=64,
|
||||
intermediate_size=16384,
|
||||
inner_group_num=1,
|
||||
down_scale_factor=1,
|
||||
hidden_act="gelu_new",
|
||||
hidden_dropout_prob=0,
|
||||
attention_probs_dropout_prob=0,
|
||||
@@ -61,7 +60,6 @@ class AlbertConfig(PretrainedConfig):
|
||||
self.num_hidden_groups = num_hidden_groups
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_group_num = inner_group_num
|
||||
self.down_scale_factor = down_scale_factor
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
|
||||
@@ -202,9 +202,6 @@ class AlbertLayerGroup(nn.Module):
|
||||
layer_attentions = ()
|
||||
|
||||
for albert_layer in self.albert_layers:
|
||||
if self.output_hidden_states:
|
||||
layer_hidden_states = layer_hidden_states + (hidden_states,)
|
||||
|
||||
layer_output = albert_layer(hidden_states, attention_mask, head_mask)
|
||||
hidden_states = layer_output[0]
|
||||
|
||||
@@ -247,7 +244,7 @@ class AlbertTransformer(nn.Module):
|
||||
hidden_states = layer_group_output[0]
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions = all_attentions + layer_group_output[1]
|
||||
all_attentions = all_attentions + layer_group_output[-1]
|
||||
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers.tokenization_albert import (AlbertTokenizer, SPIECE_UNDERLINE)
|
||||
from .tokenization_tests_commons import CommonTestCases
|
||||
|
||||
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
'fixtures/30k-clean.model')
|
||||
'fixtures/spiece.model')
|
||||
|
||||
class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user