Fix deprecation warnings for int div (#15180)

* Fix deprecation warnings for int div

Co-authored-by: mgoldey <matthew.goldey@gmail.com>

* Fix import

* ensure that tensor output is python scalar

* make backward compatible

* make code more readable

* adapt test functions

Co-authored-by: mgoldey <matthew.goldey@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger
2022-01-18 07:28:53 -05:00
committed by GitHub
parent f6d3fee855
commit 531336bbfd
10 changed files with 43 additions and 30 deletions

View File

@@ -302,6 +302,8 @@ class DataCollatorForWav2Vec2Pretraining:
batch_size = batch["input_values"].shape[0]
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
# make sure masked sequence length is a Python scalar
mask_indices_seq_length = int(mask_indices_seq_length)
# make sure that no loss is computed on padded inputs
if batch.get("attention_mask") is not None: