Update tiny model creation script and some others files (#22006)

* Update 1

* Update 2

* Update 3

* Update 4

* Update 5

* Update 6

* Update 7

* Update 8

* Update 9

* Update 10

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-03-07 22:31:14 +01:00
committed by GitHub
parent 8abe4930d3
commit b338414e61
11 changed files with 36 additions and 14 deletions

View File

@@ -56,6 +56,7 @@ class OneFormerModelTester:
parent,
batch_size=2,
is_training=True,
vocab_size=99,
use_auxiliary_loss=False,
num_queries=10,
num_channels=3,
@@ -69,6 +70,7 @@ class OneFormerModelTester:
self.parent = parent
self.batch_size = batch_size
self.is_training = is_training
self.vocab_size = vocab_size
self.use_auxiliary_loss = use_auxiliary_loss
self.num_queries = num_queries
self.num_channels = num_channels
@@ -84,12 +86,16 @@ class OneFormerModelTester:
torch_device
)
task_inputs = torch.randint(high=49408, size=(self.batch_size, self.sequence_length)).to(torch_device).long()
task_inputs = (
torch.randint(high=self.vocab_size, size=(self.batch_size, self.sequence_length)).to(torch_device).long()
)
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)
text_inputs = (
torch.randint(high=49408, size=(self.batch_size, self.num_queries - self.n_ctx, self.sequence_length))
torch.randint(
high=self.vocab_size, size=(self.batch_size, self.num_queries - self.n_ctx, self.sequence_length)
)
.to(torch_device)
.long()
)
@@ -104,6 +110,7 @@ class OneFormerModelTester:
def get_config(self):
config = OneFormerConfig(
text_encoder_vocab_size=self.vocab_size,
hidden_size=self.hidden_dim,
)
@@ -303,8 +310,10 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
size = (self.model_tester.min_size,) * 2
inputs = {
"pixel_values": torch.randn((2, 3, *size), device=torch_device),
"task_inputs": torch.randint(high=49408, size=(2, 77), device=torch_device).long(),
"text_inputs": torch.randint(high=49408, size=(2, 134, 77), device=torch_device).long(),
"task_inputs": torch.randint(high=self.model_tester.vocab_size, size=(2, 77), device=torch_device).long(),
"text_inputs": torch.randint(
high=self.model_tester.vocab_size, size=(2, 134, 77), device=torch_device
).long(),
"mask_labels": torch.randn((2, 150, *size), device=torch_device),
"class_labels": torch.zeros(2, 150, device=torch_device).long(),
}

View File

@@ -103,6 +103,7 @@ class SpeechT5ModelTester:
batch_size=13,
seq_length=7,
is_training=False,
vocab_size=81,
hidden_size=24,
num_hidden_layers=4,
num_attention_heads=2,
@@ -112,6 +113,7 @@ class SpeechT5ModelTester:
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
@@ -140,6 +142,7 @@ class SpeechT5ModelTester:
def get_config(self):
return SpeechT5Config(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,