Time to Say Goodbye, torch 1.7 and 1.8 (#22291)
* time to say goodbye, torch 1.7 and 1.8 * clean up torch_int_div * clean up is_torch_less_than_1_8-9 * update --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -71,9 +71,6 @@ if is_torch_available():
|
||||
_compute_mask_indices,
|
||||
_sample_negative_indices,
|
||||
)
|
||||
from transformers.pytorch_utils import is_torch_less_than_1_9, torch_int_div
|
||||
else:
|
||||
is_torch_less_than_1_9 = True
|
||||
|
||||
|
||||
if is_torchaudio_available():
|
||||
@@ -1217,7 +1214,9 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
sequence_length = 10
|
||||
hidden_size = 4
|
||||
num_negatives = 3
|
||||
sequence = torch_int_div(torch.arange(sequence_length * hidden_size, device=torch_device), hidden_size)
|
||||
sequence = torch.div(
|
||||
torch.arange(sequence_length * hidden_size, device=torch_device), hidden_size, rounding_mode="floor"
|
||||
)
|
||||
features = sequence.view(sequence_length, hidden_size) # each value in vector consits of same value
|
||||
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
|
||||
|
||||
@@ -1245,7 +1244,9 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
|
||||
mask[-1, sequence_length // 2 :] = 0
|
||||
|
||||
sequence = torch_int_div(torch.arange(sequence_length * hidden_size, device=torch_device), hidden_size)
|
||||
sequence = torch.div(
|
||||
torch.arange(sequence_length * hidden_size, device=torch_device), hidden_size, rounding_mode="floor"
|
||||
)
|
||||
features = sequence.view(sequence_length, hidden_size) # each value in vector consits of same value
|
||||
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
|
||||
|
||||
@@ -1651,10 +1652,6 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_torchaudio
|
||||
@unittest.skipIf(
|
||||
is_torch_less_than_1_9,
|
||||
reason="`torchaudio.functional.resample` needs torchaudio >= 0.9 which requires torch >= 0.9",
|
||||
)
|
||||
def test_wav2vec2_with_lm(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
@@ -1679,10 +1676,6 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_pyctcdecode
|
||||
@require_torchaudio
|
||||
@unittest.skipIf(
|
||||
is_torch_less_than_1_9,
|
||||
reason="`torchaudio.functional.resample` needs torchaudio >= 0.9 which requires torch >= 0.9",
|
||||
)
|
||||
def test_wav2vec2_with_lm_pool(self):
|
||||
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
||||
sample = next(iter(ds))
|
||||
|
||||
Reference in New Issue
Block a user