diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 6bb5169536..95d5c91f5a 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -109,6 +109,10 @@ def _compute_mask_indices( # scatter indices to mask spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True) + if attention_mask is not None: + # make sure padded input ids cannot be masked + spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False) + return spec_aug_mask diff --git a/src/transformers/models/hubert/modeling_tf_hubert.py b/src/transformers/models/hubert/modeling_tf_hubert.py index f37c2d77cc..0d5df863e0 100644 --- a/src/transformers/models/hubert/modeling_tf_hubert.py +++ b/src/transformers/models/hubert/modeling_tf_hubert.py @@ -258,7 +258,7 @@ def _compute_mask_indices( tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape ) - return tf.cast(spec_aug_mask, tf.float32) + return spec_aug_mask # Copied from transformers.models.bart.modeling_tf_bart._expand_mask diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py index 41f112a11e..0de7be0846 100644 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ b/src/transformers/models/roberta/modeling_tf_roberta.py @@ -508,9 +508,9 @@ class TFRobertaMainLayer(tf.keras.layers.Layer): if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif inputs["input_ids"] is not None: - input_shape = shape_list(tensor=inputs["input_ids"]) + input_shape = shape_list(inputs["input_ids"]) elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(tensor=inputs["inputs_embeds"])[:-1] + input_shape = shape_list(inputs["inputs_embeds"])[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") diff --git a/utils/check_copies.py b/utils/check_copies.py index c08fc82269..9d19fba518 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -52,7 +52,7 @@ LOCALIZED_READMES = { def _should_continue(line, indent): - return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\):\s*$", line) is not None + return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None def find_code_in_transformers(object_name):