[Sequence Feature Extraction] Add truncation (#12804)

* fix_torch_device_generate_test

* remove @

* add truncate

* finish

* correct test

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* clean tests

* correct normalization for truncation

* remove casting

* up

* save intermed

* finish

* finish

* correct

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2021-07-23 17:53:30 +02:00
committed by GitHub
parent 98364ea74f
commit f6e254474c
8 changed files with 370 additions and 38 deletions

View File

@@ -715,7 +715,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
input_speech = self._load_datasamples(2)
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
@@ -737,7 +737,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
input_speech = self._load_datasamples(4)
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)