Improve model tester (#19984)
* part 1 * part 2 * part 3 * fix * For CANINE * For ESMFold Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -46,26 +46,44 @@ class FSMTModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
src_vocab_size=99,
|
||||
tgt_vocab_size=99,
|
||||
langs=["ru", "en"],
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=False,
|
||||
use_labels=False,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=4,
|
||||
hidden_act="relu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=20,
|
||||
bos_token_id=0,
|
||||
pad_token_id=1,
|
||||
eos_token_id=2,
|
||||
):
|
||||
self.parent = parent
|
||||
self.src_vocab_size = 99
|
||||
self.tgt_vocab_size = 99
|
||||
self.langs = ["ru", "en"]
|
||||
self.batch_size = 13
|
||||
self.seq_length = 7
|
||||
self.is_training = False
|
||||
self.use_labels = False
|
||||
self.hidden_size = 16
|
||||
self.num_hidden_layers = 2
|
||||
self.num_attention_heads = 4
|
||||
self.intermediate_size = 4
|
||||
self.hidden_act = "relu"
|
||||
self.hidden_dropout_prob = 0.1
|
||||
self.attention_probs_dropout_prob = 0.1
|
||||
self.max_position_embeddings = 20
|
||||
self.bos_token_id = 0
|
||||
self.pad_token_id = 1
|
||||
self.eos_token_id = 2
|
||||
self.src_vocab_size = src_vocab_size
|
||||
self.tgt_vocab_size = tgt_vocab_size
|
||||
self.langs = langs
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.bos_token_id = bos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
torch.manual_seed(0)
|
||||
|
||||
# hack needed for modeling_common tests - despite not really having this attribute in this model
|
||||
|
||||
Reference in New Issue
Block a user