Fix deprecation warnings for int div (#15180)
* Fix deprecation warnings for int div Co-authored-by: mgoldey <matthew.goldey@gmail.com> * Fix import * ensure that tensor output is python scalar * make backward compatible * make code more readable * adapt test functions Co-authored-by: mgoldey <matthew.goldey@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -302,6 +302,8 @@ class DataCollatorForWav2Vec2Pretraining:
|
||||
batch_size = batch["input_values"].shape[0]
|
||||
|
||||
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
|
||||
# make sure masked sequence length is a Python scalar
|
||||
mask_indices_seq_length = int(mask_indices_seq_length)
|
||||
|
||||
# make sure that no loss is computed on padded inputs
|
||||
if batch.get("attention_mask") is not None:
|
||||
|
||||
Reference in New Issue
Block a user