Update tiny model creation script and some others files (#22006)
* Update 1 * Update 2 * Update 3 * Update 4 * Update 5 * Update 6 * Update 7 * Update 8 * Update 9 * Update 10 --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -56,6 +56,7 @@ class OneFormerModelTester:
|
||||
parent,
|
||||
batch_size=2,
|
||||
is_training=True,
|
||||
vocab_size=99,
|
||||
use_auxiliary_loss=False,
|
||||
num_queries=10,
|
||||
num_channels=3,
|
||||
@@ -69,6 +70,7 @@ class OneFormerModelTester:
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.is_training = is_training
|
||||
self.vocab_size = vocab_size
|
||||
self.use_auxiliary_loss = use_auxiliary_loss
|
||||
self.num_queries = num_queries
|
||||
self.num_channels = num_channels
|
||||
@@ -84,12 +86,16 @@ class OneFormerModelTester:
|
||||
torch_device
|
||||
)
|
||||
|
||||
task_inputs = torch.randint(high=49408, size=(self.batch_size, self.sequence_length)).to(torch_device).long()
|
||||
task_inputs = (
|
||||
torch.randint(high=self.vocab_size, size=(self.batch_size, self.sequence_length)).to(torch_device).long()
|
||||
)
|
||||
|
||||
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)
|
||||
|
||||
text_inputs = (
|
||||
torch.randint(high=49408, size=(self.batch_size, self.num_queries - self.n_ctx, self.sequence_length))
|
||||
torch.randint(
|
||||
high=self.vocab_size, size=(self.batch_size, self.num_queries - self.n_ctx, self.sequence_length)
|
||||
)
|
||||
.to(torch_device)
|
||||
.long()
|
||||
)
|
||||
@@ -104,6 +110,7 @@ class OneFormerModelTester:
|
||||
|
||||
def get_config(self):
|
||||
config = OneFormerConfig(
|
||||
text_encoder_vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_dim,
|
||||
)
|
||||
|
||||
@@ -303,8 +310,10 @@ class OneFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
size = (self.model_tester.min_size,) * 2
|
||||
inputs = {
|
||||
"pixel_values": torch.randn((2, 3, *size), device=torch_device),
|
||||
"task_inputs": torch.randint(high=49408, size=(2, 77), device=torch_device).long(),
|
||||
"text_inputs": torch.randint(high=49408, size=(2, 134, 77), device=torch_device).long(),
|
||||
"task_inputs": torch.randint(high=self.model_tester.vocab_size, size=(2, 77), device=torch_device).long(),
|
||||
"text_inputs": torch.randint(
|
||||
high=self.model_tester.vocab_size, size=(2, 134, 77), device=torch_device
|
||||
).long(),
|
||||
"mask_labels": torch.randn((2, 150, *size), device=torch_device),
|
||||
"class_labels": torch.zeros(2, 150, device=torch_device).long(),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user