Style
This commit is contained in:
@@ -52,35 +52,35 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
class FlaubertModelTester(object):
|
class FlaubertModelTester(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=13,
|
batch_size=13,
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
use_input_lengths=True,
|
use_input_lengths=True,
|
||||||
use_token_type_ids=True,
|
use_token_type_ids=True,
|
||||||
use_labels=True,
|
use_labels=True,
|
||||||
gelu_activation=True,
|
gelu_activation=True,
|
||||||
sinusoidal_embeddings=False,
|
sinusoidal_embeddings=False,
|
||||||
causal=False,
|
causal=False,
|
||||||
asm=False,
|
asm=False,
|
||||||
n_langs=2,
|
n_langs=2,
|
||||||
vocab_size=99,
|
vocab_size=99,
|
||||||
n_special=0,
|
n_special=0,
|
||||||
hidden_size=32,
|
hidden_size=32,
|
||||||
num_hidden_layers=5,
|
num_hidden_layers=5,
|
||||||
num_attention_heads=4,
|
num_attention_heads=4,
|
||||||
hidden_dropout_prob=0.1,
|
hidden_dropout_prob=0.1,
|
||||||
attention_probs_dropout_prob=0.1,
|
attention_probs_dropout_prob=0.1,
|
||||||
max_position_embeddings=512,
|
max_position_embeddings=512,
|
||||||
type_vocab_size=16,
|
type_vocab_size=16,
|
||||||
type_sequence_label_size=2,
|
type_sequence_label_size=2,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
num_labels=3,
|
num_labels=3,
|
||||||
num_choices=4,
|
num_choices=4,
|
||||||
summary_type="last",
|
summary_type="last",
|
||||||
use_proj=True,
|
use_proj=True,
|
||||||
scope=None,
|
scope=None,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -119,7 +119,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
input_lengths = None
|
input_lengths = None
|
||||||
if self.use_input_lengths:
|
if self.use_input_lengths:
|
||||||
input_lengths = (
|
input_lengths = (
|
||||||
ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2
|
ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2
|
||||||
) # small variation of seq_length
|
) # small variation of seq_length
|
||||||
|
|
||||||
token_type_ids = None
|
token_type_ids = None
|
||||||
@@ -168,15 +168,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||||
|
|
||||||
def create_and_check_flaubert_model(
|
def create_and_check_flaubert_model(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
is_impossible_labels,
|
is_impossible_labels,
|
||||||
input_mask,
|
input_mask,
|
||||||
):
|
):
|
||||||
model = FlaubertModel(config=config)
|
model = FlaubertModel(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -193,15 +193,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_flaubert_lm_head(
|
def create_and_check_flaubert_lm_head(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
is_impossible_labels,
|
is_impossible_labels,
|
||||||
input_mask,
|
input_mask,
|
||||||
):
|
):
|
||||||
model = FlaubertWithLMHeadModel(config)
|
model = FlaubertWithLMHeadModel(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -220,15 +220,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_flaubert_simple_qa(
|
def create_and_check_flaubert_simple_qa(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
is_impossible_labels,
|
is_impossible_labels,
|
||||||
input_mask,
|
input_mask,
|
||||||
):
|
):
|
||||||
model = FlaubertForQuestionAnsweringSimple(config)
|
model = FlaubertForQuestionAnsweringSimple(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -249,15 +249,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
def create_and_check_flaubert_qa(
|
def create_and_check_flaubert_qa(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
is_impossible_labels,
|
is_impossible_labels,
|
||||||
input_mask,
|
input_mask,
|
||||||
):
|
):
|
||||||
model = FlaubertForQuestionAnswering(config)
|
model = FlaubertForQuestionAnswering(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -316,15 +316,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
|
self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
|
||||||
|
|
||||||
def create_and_check_flaubert_sequence_classif(
|
def create_and_check_flaubert_sequence_classif(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
token_type_ids,
|
token_type_ids,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
is_impossible_labels,
|
is_impossible_labels,
|
||||||
input_mask,
|
input_mask,
|
||||||
):
|
):
|
||||||
model = FlaubertForSequenceClassification(config)
|
model = FlaubertForSequenceClassification(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
def create_and_check_roberta_for_multiple_choice(
|
def create_and_check_roberta_for_multiple_choice(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
config.num_choices = self.num_choices
|
config.num_choices = self.num_choices
|
||||||
model = RobertaForMultipleChoice(config=config)
|
model = RobertaForMultipleChoice(config=config)
|
||||||
@@ -208,7 +208,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
def create_and_check_roberta_for_question_answering(
|
def create_and_check_roberta_for_question_answering(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
model = RobertaForQuestionAnswering(config=config)
|
model = RobertaForQuestionAnswering(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user