[CI] Check test if the GenerationTesterMixin inheritance is correct 🐛 🔫 (#36180)
This commit is contained in:
@@ -19,6 +19,7 @@ import copy
|
||||
import datetime
|
||||
import gc
|
||||
import inspect
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
import warnings
|
||||
@@ -48,8 +49,6 @@ from transformers.testing_utils import (
|
||||
)
|
||||
from transformers.utils import is_ipex_available
|
||||
|
||||
from ..test_modeling_common import floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@@ -2786,6 +2785,43 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
self.assertTrue(last_token_counts[8] > last_token_counts[3])
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
# Copied from tests.test_modeling_common.ids_tensor
|
||||
def ids_tensor(shape, vocab_size, rng=None, name=None):
|
||||
# Creates a random int32 tensor of the shape within the vocab size
|
||||
if rng is None:
|
||||
rng = global_rng
|
||||
|
||||
total_dims = 1
|
||||
for dim in shape:
|
||||
total_dims *= dim
|
||||
|
||||
values = []
|
||||
for _ in range(total_dims):
|
||||
values.append(rng.randint(0, vocab_size - 1))
|
||||
|
||||
return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()
|
||||
|
||||
|
||||
# Copied from tests.test_modeling_common.floats_tensor
|
||||
def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
"""Creates a random float32 tensor"""
|
||||
if rng is None:
|
||||
rng = global_rng
|
||||
|
||||
total_dims = 1
|
||||
for dim in shape:
|
||||
total_dims *= dim
|
||||
|
||||
values = []
|
||||
for _ in range(total_dims):
|
||||
values.append(rng.random() * scale)
|
||||
|
||||
return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()
|
||||
|
||||
|
||||
@pytest.mark.generate
|
||||
@require_torch
|
||||
class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user