Improve model tester (#19984)

* part 1

* part 2

* part 3

* fix

* For CANINE

* For ESMFold

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-11-02 17:38:44 +01:00
committed by GitHub
parent 7487743793
commit f69eb24b5a
16 changed files with 654 additions and 327 deletions

View File

@@ -46,26 +46,44 @@ class FSMTModelTester:
def __init__(
self,
parent,
src_vocab_size=99,
tgt_vocab_size=99,
langs=["ru", "en"],
batch_size=13,
seq_length=7,
is_training=False,
use_labels=False,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
hidden_act="relu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
bos_token_id=0,
pad_token_id=1,
eos_token_id=2,
):
self.parent = parent
self.src_vocab_size = 99
self.tgt_vocab_size = 99
self.langs = ["ru", "en"]
self.batch_size = 13
self.seq_length = 7
self.is_training = False
self.use_labels = False
self.hidden_size = 16
self.num_hidden_layers = 2
self.num_attention_heads = 4
self.intermediate_size = 4
self.hidden_act = "relu"
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 20
self.bos_token_id = 0
self.pad_token_id = 1
self.eos_token_id = 2
self.src_vocab_size = src_vocab_size
self.tgt_vocab_size = tgt_vocab_size
self.langs = langs
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.eos_token_id = eos_token_id
torch.manual_seed(0)
# hack needed for modeling_common tests - despite not really having this attribute in this model