From f0fd73a2de0a611a3826885c0f529493ab32ace0 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Fri, 4 Aug 2023 14:56:29 +0200
Subject: [PATCH] Document check copies (#25291)
* Document check copies better and add tests
* Include header in check for copies
* Manual fixes
* Try autofix
* Fixes
* Clean tests
* Finalize doc
* Remove debug print
* More fixes
---
docs/source/en/add_new_model.md | 2 +-
docs/source/en/pr_checks.md | 55 +++++
src/transformers/__init__.py | 2 +
.../models/albert/modeling_flax_albert.py | 1 -
.../models/albert/tokenization_albert.py | 4 +-
.../models/align/modeling_align.py | 2 +-
src/transformers/models/bart/__init__.py | 2 +
src/transformers/models/bart/modeling_bart.py | 30 ++-
src/transformers/models/bit/modeling_bit.py | 2 +-
.../blenderbot/modeling_flax_blenderbot.py | 3 +-
.../modeling_flax_blenderbot_small.py | 2 +-
.../models/clipseg/modeling_clipseg.py | 2 +-
.../image_processing_conditional_detr.py | 2 +-
.../models/convnext/modeling_convnext.py | 2 +-
.../models/convnextv2/modeling_convnextv2.py | 2 +-
src/transformers/models/cvt/modeling_cvt.py | 2 +-
.../models/deprecated/van/modeling_van.py | 2 +-
.../models/dinat/modeling_dinat.py | 2 +-
.../models/dinov2/modeling_dinov2.py | 2 +-
.../models/donut/modeling_donut_swin.py | 4 +-
.../modeling_efficientformer.py | 2 +-
.../models/esm/modeling_tf_esm.py | 2 +-
.../models/focalnet/modeling_focalnet.py | 2 +-
src/transformers/models/glpn/modeling_glpn.py | 4 +-
.../gpt_bigcode/modeling_gpt_bigcode.py | 2 +-
.../models/llama/tokenization_llama.py | 7 +-
.../models/longt5/modeling_flax_longt5.py | 2 +-
.../models/marian/modeling_flax_marian.py | 2 +-
.../maskformer/modeling_maskformer_swin.py | 2 +-
.../models/mgp_str/modeling_mgp_str.py | 2 +-
src/transformers/models/nat/modeling_nat.py | 2 +-
.../models/owlvit/image_processing_owlvit.py | 1 -
.../models/owlvit/modeling_owlvit.py | 8 +-
.../models/pegasus/modeling_flax_pegasus.py | 12 +-
.../models/poolformer/modeling_poolformer.py | 2 +-
src/transformers/models/pvt/modeling_pvt.py | 4 +-
.../models/segformer/modeling_segformer.py | 4 +-
.../swiftformer/modeling_swiftformer.py | 2 +-
src/transformers/models/swin/modeling_swin.py | 2 +-
.../models/swin2sr/modeling_swin2sr.py | 4 +-
.../models/swinv2/modeling_swinv2.py | 2 +-
.../models/t5/modeling_flax_t5.py | 2 +-
.../modeling_wav2vec2_conformer.py | 2 +-
.../models/whisper/tokenization_whisper.py | 8 +-
.../whisper/tokenization_whisper_fast.py | 8 +-
.../models/x_clip/modeling_x_clip.py | 2 +-
.../models/xglm/modeling_tf_xglm.py | 2 +-
.../models/yolos/image_processing_yolos.py | 2 +-
src/transformers/utils/dummy_pt_objects.py | 7 +
tests/repo_utils/test_check_copies.py | 197 ++++++++++++------
utils/check_copies.py | 123 +++++++++--
51 files changed, 382 insertions(+), 166 deletions(-)
diff --git a/docs/source/en/add_new_model.md b/docs/source/en/add_new_model.md
index b330535408..4072be6f59 100644
--- a/docs/source/en/add_new_model.md
+++ b/docs/source/en/add_new_model.md
@@ -101,7 +101,7 @@ own regarding how code should be written :-)
1. The forward pass of your model should be fully written in the modeling file while being fully independent of other
models in the library. If you want to reuse a block from another model, copy the code and paste it with a
`# Copied from` comment on top (see [here](https://github.com/huggingface/transformers/blob/v4.17.0/src/transformers/models/roberta/modeling_roberta.py#L160)
- for a good example).
+ for a good example and [there](pr_checks#check-copies) for more documentation on Copied from).
2. The code should be fully understandable, even by a non-native English speaker. This means you should pick
descriptive variable names and avoid abbreviations. As an example, `activation` is preferred to `act`.
One-letter variable names are strongly discouraged unless it's an index in a for loop.
diff --git a/docs/source/en/pr_checks.md b/docs/source/en/pr_checks.md
index 6aeee89d75..c5a2e539c0 100644
--- a/docs/source/en/pr_checks.md
+++ b/docs/source/en/pr_checks.md
@@ -142,3 +142,58 @@ Additional checks concern PRs that add new models, mainly that:
- All checkpoints used actually exist on the Hub
-->
+
+### Check copies
+
+Since the Transformers library is very opinionated with respect to model code, and each model should fully be implemented in a single file without relying on other models, we have added a mechanism that checks whether a copy of the code of a layer of a given model stays consistent with the original. This way, when there is a bug fix, we can see all other impacted models and choose to trickle down the modification or break the copy.
+
+
+
+If a file is a full copy of another file, you should register it in the constant `FULL_COPIES` of `utils/check_copies.py`.
+
+
+
+This mechanism relies on comments of the form `# Copied from xxx`. The `xxx` should contain the whole path to the class of function which is being copied below. For instance, `RobertaSelfOutput` is a direct copy of the `BertSelfOutput` class, so you can see [here](https://github.com/huggingface/transformers/blob/2bd7a27a671fd1d98059124024f580f8f5c0f3b5/src/transformers/models/roberta/modeling_roberta.py#L289) it has a comment:
+
+```py
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
+```
+
+Note that instead of applying this to a whole class, you can apply it to the relevant methods that are copied from. For instance [here](https://github.com/huggingface/transformers/blob/2bd7a27a671fd1d98059124024f580f8f5c0f3b5/src/transformers/models/roberta/modeling_roberta.py#L598) you can see how `RobertaPreTrainedModel._init_weights` is copied from the same method in `BertPreTrainedModel` with the comment:
+
+```py
+# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
+```
+
+Sometimes the copy is exactly the same except for names: for instance in `RobertaAttention`, we use `RobertaSelfAttention` insted of `BertSelfAttention` but other than that, the code is exactly the same. This is why `# Copied from` supports simple string replacements with the follwoing syntax: `Copied from xxx with foo->bar`. This means the code is copied with all instances of `foo` being replaced by `bar`. You can see how it used [here](https://github.com/huggingface/transformers/blob/2bd7a27a671fd1d98059124024f580f8f5c0f3b5/src/transformers/models/roberta/modeling_roberta.py#L304C1-L304C86) in `RobertaAttention` with the comment:
+
+```py
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
+```
+
+Note that there shouldn't be any spaces around the arrow (unless that space is part of the pattern to replace of course).
+
+You can add several patterns separated by a comma. For instance here `CamemberForMaskedLM` is a direct copy of `RobertaForMaskedLM` with two replacements: `Roberta` to `Camembert` and `ROBERTA` to `CAMEMBERT`. You can see [here](https://github.com/huggingface/transformers/blob/15082a9dc6950ecae63a0d3e5060b2fc7f15050a/src/transformers/models/camembert/modeling_camembert.py#L929) this is done with the comment:
+
+```py
+# Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM with Roberta->Camembert, ROBERTA->CAMEMBERT
+```
+
+If the order matters (because one of the replacements might conflict with a previous one), the replacements are executed from left to right.
+
+
+
+If the replacements change the formatting (if you replace a short name by a very long name for instance), the copy is checked after applying the auto-formatter.
+
+
+
+Another way when the patterns are just different casings of the same replacement (with an uppercased and a lowercased variants) is just to add the option `all-casing`. [Here](https://github.com/huggingface/transformers/blob/15082a9dc6950ecae63a0d3e5060b2fc7f15050a/src/transformers/models/mobilebert/modeling_mobilebert.py#L1237) is an example in `MobileBertForSequenceClassification` with the comment:
+
+```py
+# Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification with Bert->MobileBert all-casing
+```
+
+In this case, the code is copied from `BertForSequenceClassification` by replacing:
+- `Bert` by `MobileBert` (for instance when using `MobileBertModel` in the init)
+- `bert` by `mobilebert` (for instance when defining `self.mobilebert`)
+- `BERT` by `MOBILEBERT` (in the constant `MOBILEBERT_INPUTS_DOCSTRING`)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 7994a5b6ed..2253bda390 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1168,6 +1168,7 @@ else:
"BartForSequenceClassification",
"BartModel",
"BartPretrainedModel",
+ "BartPreTrainedModel",
"PretrainedBartModel",
]
)
@@ -5072,6 +5073,7 @@ if TYPE_CHECKING:
BartForQuestionAnswering,
BartForSequenceClassification,
BartModel,
+ BartPreTrainedModel,
BartPretrainedModel,
PretrainedBartModel,
)
diff --git a/src/transformers/models/albert/modeling_flax_albert.py b/src/transformers/models/albert/modeling_flax_albert.py
index 0ff1b9276a..55fd9d5a4c 100644
--- a/src/transformers/models/albert/modeling_flax_albert.py
+++ b/src/transformers/models/albert/modeling_flax_albert.py
@@ -173,7 +173,6 @@ class FlaxAlbertEmbeddings(nn.Module):
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
- # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.__call__
def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True):
# Embed
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
diff --git a/src/transformers/models/albert/tokenization_albert.py b/src/transformers/models/albert/tokenization_albert.py
index b043a14989..231abf1c03 100644
--- a/src/transformers/models/albert/tokenization_albert.py
+++ b/src/transformers/models/albert/tokenization_albert.py
@@ -183,10 +183,10 @@ class AlbertTokenizer(PreTrainedTokenizer):
self.sp_model.Load(vocab_file)
@property
- def vocab_size(self):
+ def vocab_size(self) -> int:
return len(self.sp_model)
- def get_vocab(self):
+ def get_vocab(self) -> Dict[str, int]:
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py
index a60be94224..4c5f1d3138 100644
--- a/src/transformers/models/align/modeling_align.py
+++ b/src/transformers/models/align/modeling_align.py
@@ -286,7 +286,7 @@ def align_loss(similarity: torch.Tensor) -> torch.Tensor:
return (caption_loss + image_loss) / 2.0
-# Copied from transformers.models.efficientnet.modeling_efficientnet.round_filters with EfficientNet -> AlignVision
+# Copied from transformers.models.efficientnet.modeling_efficientnet.round_filters with EfficientNet->AlignVision
def round_filters(config: AlignVisionConfig, num_channels: int):
r"""
Round number of filters based on depth multiplier.
diff --git a/src/transformers/models/bart/__init__.py b/src/transformers/models/bart/__init__.py
index 7129474b4e..4f104efce1 100644
--- a/src/transformers/models/bart/__init__.py
+++ b/src/transformers/models/bart/__init__.py
@@ -49,6 +49,7 @@ else:
"BartForQuestionAnswering",
"BartForSequenceClassification",
"BartModel",
+ "BartPreTrainedModel",
"BartPretrainedModel",
"PretrainedBartModel",
]
@@ -107,6 +108,7 @@ if TYPE_CHECKING:
BartForQuestionAnswering,
BartForSequenceClassification,
BartModel,
+ BartPreTrainedModel,
BartPretrainedModel,
PretrainedBartModel,
)
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 91f7bac906..fe3fbb1f8c 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -502,7 +502,7 @@ class BartClassificationHead(nn.Module):
return hidden_states
-class BartPretrainedModel(PreTrainedModel):
+class BartPreTrainedModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
@@ -536,10 +536,18 @@ class BartPretrainedModel(PreTrainedModel):
return dummy_inputs
-class PretrainedBartModel(BartPretrainedModel):
+class PretrainedBartModel(BartPreTrainedModel):
def __init_subclass__(self):
warnings.warn(
- "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.",
+ "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.",
+ FutureWarning,
+ )
+
+
+class BartPretrainedModel(BartPreTrainedModel):
+ def __init_subclass__(self):
+ warnings.warn(
+ "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.",
FutureWarning,
)
@@ -700,7 +708,7 @@ BART_INPUTS_DOCSTRING = r"""
"""
-class BartEncoder(BartPretrainedModel):
+class BartEncoder(BartPreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`BartEncoderLayer`].
@@ -882,7 +890,7 @@ class BartEncoder(BartPretrainedModel):
)
-class BartDecoder(BartPretrainedModel):
+class BartDecoder(BartPreTrainedModel):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]
@@ -1169,7 +1177,7 @@ class BartDecoder(BartPretrainedModel):
"The bare BART Model outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
-class BartModel(BartPretrainedModel):
+class BartModel(BartPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig):
@@ -1296,7 +1304,7 @@ class BartModel(BartPretrainedModel):
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
-class BartForConditionalGeneration(BartPretrainedModel):
+class BartForConditionalGeneration(BartPreTrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
@@ -1471,7 +1479,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
""",
BART_START_DOCSTRING,
)
-class BartForSequenceClassification(BartPretrainedModel):
+class BartForSequenceClassification(BartPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig, **kwargs):
@@ -1601,7 +1609,7 @@ class BartForSequenceClassification(BartPretrainedModel):
""",
BART_START_DOCSTRING,
)
-class BartForQuestionAnswering(BartPretrainedModel):
+class BartForQuestionAnswering(BartPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
@@ -1719,7 +1727,7 @@ class BartForQuestionAnswering(BartPretrainedModel):
)
-class BartDecoderWrapper(BartPretrainedModel):
+class BartDecoderWrapper(BartPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the [`EncoderDecoderModel`] framework.
@@ -1739,7 +1747,7 @@ class BartDecoderWrapper(BartPretrainedModel):
""",
BART_START_DOCSTRING,
)
-class BartForCausalLM(BartPretrainedModel):
+class BartForCausalLM(BartPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py
index 284ff5e2de..12a5ecd42b 100644
--- a/src/transformers/models/bit/modeling_bit.py
+++ b/src/transformers/models/bit/modeling_bit.py
@@ -300,7 +300,7 @@ class BitEmbeddings(nn.Module):
# Copied from transformers.models.convnext.modeling_convnext.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
index 6796f48163..3f5c73a6c3 100644
--- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
@@ -22,7 +22,6 @@ from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
-import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
@@ -205,7 +204,7 @@ BLENDERBOT_DECODE_INPUTS_DOCSTRING = r"""
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
index e13b90c060..77e6b1704b 100644
--- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
@@ -216,7 +216,7 @@ BLENDERBOT_SMALL_DECODE_INPUTS_DOCSTRING = r"""
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py
index 4cab4425f1..3dc006179c 100644
--- a/src/transformers/models/clipseg/modeling_clipseg.py
+++ b/src/transformers/models/clipseg/modeling_clipseg.py
@@ -160,7 +160,7 @@ class CLIPSegImageSegmentationOutput(ModelOutput):
class CLIPSegVisionEmbeddings(nn.Module):
- # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.__init__
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.__init__ with CLIP->CLIPSeg
def __init__(self, config: CLIPSegVisionConfig):
super().__init__()
self.config = config
diff --git a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py
index 5bb90d5d74..4f6497d112 100644
--- a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py
+++ b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py
@@ -861,7 +861,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
return target
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare
- def prepare(self, image, target, return_segmentation_masks=False, masks_path=None):
+ def prepare(self, image, target, return_segmentation_masks=None, masks_path=None):
logger.warning_once(
"The `prepare` method is deprecated and will be removed in a v4.33. "
"Please use `prepare_annotation` instead. Note: the `prepare_annotation` method "
diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py
index 3733fb9414..e6cf336517 100755
--- a/src/transformers/models/convnext/modeling_convnext.py
+++ b/src/transformers/models/convnext/modeling_convnext.py
@@ -61,7 +61,7 @@ CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py
index 70c35a85af..3a268c713d 100644
--- a/src/transformers/models/convnextv2/modeling_convnextv2.py
+++ b/src/transformers/models/convnextv2/modeling_convnextv2.py
@@ -61,7 +61,7 @@ CONVNEXTV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py
index 99e3a02feb..d21b5c9a87 100644
--- a/src/transformers/models/cvt/modeling_cvt.py
+++ b/src/transformers/models/cvt/modeling_cvt.py
@@ -78,7 +78,7 @@ class BaseModelOutputWithCLSToken(ModelOutput):
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/deprecated/van/modeling_van.py b/src/transformers/models/deprecated/van/modeling_van.py
index f7feebae4d..4ef18f5415 100644
--- a/src/transformers/models/deprecated/van/modeling_van.py
+++ b/src/transformers/models/deprecated/van/modeling_van.py
@@ -54,7 +54,7 @@ VAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
# Copied from transformers.models.convnext.modeling_convnext.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py
index b15d7d187e..89c6ed2e2a 100644
--- a/src/transformers/models/dinat/modeling_dinat.py
+++ b/src/transformers/models/dinat/modeling_dinat.py
@@ -269,7 +269,7 @@ class DinatDownsampler(nn.Module):
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py
index 3e49b50f21..a0cf8be82f 100644
--- a/src/transformers/models/dinov2/modeling_dinov2.py
+++ b/src/transformers/models/dinov2/modeling_dinov2.py
@@ -316,7 +316,7 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
-class Dinov2DropPath:
+class Dinov2DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
index 65c48eb81f..0d833406e2 100644
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -295,8 +295,8 @@ class DonutSwinPatchMerging(nn.Module):
return input_feature
-# Copied from transformers.models.swin.modeling_swin.drop_path
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/efficientformer/modeling_efficientformer.py b/src/transformers/models/efficientformer/modeling_efficientformer.py
index c3ed5cace8..5f03a5ab74 100644
--- a/src/transformers/models/efficientformer/modeling_efficientformer.py
+++ b/src/transformers/models/efficientformer/modeling_efficientformer.py
@@ -246,7 +246,7 @@ class EfficientFormerConvMlp(nn.Module):
# Copied from transformers.models.convnext.modeling_convnext.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/esm/modeling_tf_esm.py b/src/transformers/models/esm/modeling_tf_esm.py
index c5cdf53c59..3e9223087b 100644
--- a/src/transformers/models/esm/modeling_tf_esm.py
+++ b/src/transformers/models/esm/modeling_tf_esm.py
@@ -667,7 +667,7 @@ class TFEsmEncoder(Layer):
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Esm
-class TFEsmPooler(Layer):
+class TFEsmPooler(tf.keras.layers.Layer):
def __init__(self, config: EsmConfig, **kwargs):
super().__init__(**kwargs)
diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py
index fc327ad0b3..8d18a8c63f 100644
--- a/src/transformers/models/focalnet/modeling_focalnet.py
+++ b/src/transformers/models/focalnet/modeling_focalnet.py
@@ -286,7 +286,7 @@ class FocalNetPatchEmbeddings(nn.Module):
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py
index d9ebb64f66..d2ddef5c41 100755
--- a/src/transformers/models/glpn/modeling_glpn.py
+++ b/src/transformers/models/glpn/modeling_glpn.py
@@ -52,8 +52,8 @@ GLPN_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
-# Copied from transformers.models.segformer.modeling_segformer.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index d0f71d382f..415c6ac0dc 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -272,7 +272,7 @@ class GPTBigCodeMLP(nn.Module):
self.dropout = nn.Dropout(config.resid_pdrop)
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
- def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
diff --git a/src/transformers/models/llama/tokenization_llama.py b/src/transformers/models/llama/tokenization_llama.py
index 110ffdce75..667b92793a 100644
--- a/src/transformers/models/llama/tokenization_llama.py
+++ b/src/transformers/models/llama/tokenization_llama.py
@@ -30,7 +30,8 @@ from ...utils import logging
if TYPE_CHECKING:
- from transformers.pipelines.conversational import Conversation
+ from ...pipelines.conversational import Conversation
+ from ...tokenization_utils_base import TextInput
logger = logging.get_logger(__name__)
@@ -168,7 +169,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
return vocab
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
- def tokenize(self, text, **kwargs) -> List[str]:
+ def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
# Replace the SPIECE_UNDERLINE with a space to make sure SPIECE_UNDERLINE is only used at
# the beginning of the text
if not self.legacy:
@@ -176,7 +177,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
return super().tokenize(text, **kwargs)
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
- def _tokenize(self, text):
+ def _tokenize(self, text, **kwargs):
"""
Returns a tokenized string.
diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py
index 7fa708c599..96c0b7df2c 100644
--- a/src/transformers/models/longt5/modeling_flax_longt5.py
+++ b/src/transformers/models/longt5/modeling_flax_longt5.py
@@ -56,7 +56,7 @@ remat = nn_partitioning.remat
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py
index c3d89b693a..f197126277 100644
--- a/src/transformers/models/marian/modeling_flax_marian.py
+++ b/src/transformers/models/marian/modeling_flax_marian.py
@@ -227,7 +227,7 @@ def create_sinusoidal_positions(n_pos, dim):
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py
index 7016b598e8..357ac9d4aa 100644
--- a/src/transformers/models/maskformer/modeling_maskformer_swin.py
+++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -123,7 +123,7 @@ def window_reverse(windows, window_size, height, width):
# Copied from transformers.models.swin.modeling_swin.drop_path
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py
index 35ed55f5f5..5e34faf408 100644
--- a/src/transformers/models/mgp_str/modeling_mgp_str.py
+++ b/src/transformers/models/mgp_str/modeling_mgp_str.py
@@ -51,7 +51,7 @@ MGP_STR_PRETRAINED_MODEL_ARCHIVE_LIST = [
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/nat/modeling_nat.py b/src/transformers/models/nat/modeling_nat.py
index 2293661f2b..ecc745b558 100644
--- a/src/transformers/models/nat/modeling_nat.py
+++ b/src/transformers/models/nat/modeling_nat.py
@@ -263,7 +263,7 @@ class NatDownsampler(nn.Module):
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/owlvit/image_processing_owlvit.py b/src/transformers/models/owlvit/image_processing_owlvit.py
index 0dccdf129a..684bb40f2d 100644
--- a/src/transformers/models/owlvit/image_processing_owlvit.py
+++ b/src/transformers/models/owlvit/image_processing_owlvit.py
@@ -47,7 +47,6 @@ if is_torch_available():
logger = logging.get_logger(__name__)
-# Copied from transformers.models.detr.modeling_detr._upcast
def _upcast(t):
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():
diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py
index 2cf67e63f2..f2a9607a6e 100644
--- a/src/transformers/models/owlvit/modeling_owlvit.py
+++ b/src/transformers/models/owlvit/modeling_owlvit.py
@@ -22,7 +22,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
-from torch import nn
+from torch import Tensor, nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
@@ -120,7 +120,7 @@ class OwlViTOutput(ModelOutput):
# Copied from transformers.models.detr.modeling_detr._upcast
-def _upcast(t: torch.Tensor) -> torch.Tensor:
+def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():
return t if t.dtype in (torch.float32, torch.float64) else t.float()
@@ -129,7 +129,7 @@ def _upcast(t: torch.Tensor) -> torch.Tensor:
# Copied from transformers.models.detr.modeling_detr.box_area
-def box_area(boxes: torch.Tensor) -> torch.Tensor:
+def box_area(boxes: Tensor) -> Tensor:
"""
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
@@ -146,7 +146,7 @@ def box_area(boxes: torch.Tensor) -> torch.Tensor:
# Copied from transformers.models.detr.modeling_detr.box_iou
-def box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
+def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py
index ddd83709e9..fdf7f019f2 100644
--- a/src/transformers/models/pegasus/modeling_flax_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -210,7 +210,7 @@ PEGASUS_DECODE_INPUTS_DOCSTRING = r"""
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
@@ -223,7 +223,7 @@ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_tok
# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions
-def create_sinusoidal_positions(n_pos, dim, dtype):
+def create_sinusoidal_positions(n_pos, dim):
position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
sentinel = dim // 2 + dim % 2
out = np.zeros_like(position_enc)
@@ -686,9 +686,7 @@ class FlaxPegasusEncoder(nn.Module):
self.max_source_positions = self.config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
- self.embed_positions = create_sinusoidal_positions(
- self.config.max_position_embeddings, embed_dim, dtype=self.dtype
- )
+ self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
self.layers = FlaxPegasusEncoderLayerCollection(self.config, self.dtype)
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
@@ -755,9 +753,7 @@ class FlaxPegasusDecoder(nn.Module):
self.max_target_positions = self.config.max_position_embeddings
self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
- self.embed_positions = create_sinusoidal_positions(
- self.config.max_position_embeddings, embed_dim, dtype=self.dtype
- )
+ self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
self.layers = FlaxPegasusDecoderLayerCollection(self.config, self.dtype)
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py
index 688a9239f0..6acc8ec98e 100755
--- a/src/transformers/models/poolformer/modeling_poolformer.py
+++ b/src/transformers/models/poolformer/modeling_poolformer.py
@@ -50,7 +50,7 @@ POOLFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py
index 09f75092a8..2dd452ec1d 100755
--- a/src/transformers/models/pvt/modeling_pvt.py
+++ b/src/transformers/models/pvt/modeling_pvt.py
@@ -55,8 +55,8 @@ PVT_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
-# Copied from transformers.models.convnext.modeling_convnext.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True):
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py
index 6701f66f9a..47f42b5e0e 100755
--- a/src/transformers/models/segformer/modeling_segformer.py
+++ b/src/transformers/models/segformer/modeling_segformer.py
@@ -84,8 +84,8 @@ class SegFormerImageClassifierOutput(ImageClassifierOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
-# Copied from transformers.models.convnext.modeling_convnext.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True):
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py
index a29ed38fb4..ff72f87506 100644
--- a/src/transformers/models/swiftformer/modeling_swiftformer.py
+++ b/src/transformers/models/swiftformer/modeling_swiftformer.py
@@ -86,7 +86,7 @@ class SwiftFormerPatchEmbedding(nn.Module):
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index b324cfdcd9..2cf1d33a51 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -380,7 +380,7 @@ class SwinPatchMerging(nn.Module):
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py
index cd58b70650..9464981baf 100644
--- a/src/transformers/models/swin2sr/modeling_swin2sr.py
+++ b/src/transformers/models/swin2sr/modeling_swin2sr.py
@@ -105,8 +105,8 @@ def window_reverse(windows, window_size, height, width):
return windows
-# Copied from transformers.models.swin.modeling_swin.drop_path
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py
index 97b460479d..e05643a635 100644
--- a/src/transformers/models/swinv2/modeling_swinv2.py
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -242,7 +242,7 @@ def window_reverse(windows, window_size, height, width):
# Copied from transformers.models.swin.modeling_swin.drop_path
-def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py
index bc26ade028..cc74c30c1d 100644
--- a/src/transformers/models/t5/modeling_flax_t5.py
+++ b/src/transformers/models/t5/modeling_flax_t5.py
@@ -56,7 +56,7 @@ remat = nn_partitioning.remat
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
-def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
+def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index d5836de339..f4392073b9 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -1603,7 +1603,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
)
class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
- def __init__(self, config, target_lang=None):
+ def __init__(self, config, target_lang: Optional[str] = None):
super().__init__(config)
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index 6053f479aa..45fd5ed4e7 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -15,7 +15,7 @@
"""Tokenization classes for Whisper."""
import json
import os
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np
import regex as re
@@ -25,6 +25,10 @@ from ...utils import logging
from .english_normalizer import EnglishTextNormalizer
+if TYPE_CHECKING:
+ from ...pipelines.conversational import Conversation
+
+
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"tokenizer_file": "tokenizer.json",
@@ -697,7 +701,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
return (text, kwargs)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._build_conversation_input_ids with GPT2 -> Whisper
- def _build_conversation_input_ids(self, conversation) -> List[int]:
+ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
input_ids = []
for is_user, text in conversation.iter_texts():
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py
index 4861de6528..689da15000 100644
--- a/src/transformers/models/whisper/tokenization_whisper_fast.py
+++ b/src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -15,7 +15,7 @@
"""Tokenization classes for Whisper."""
import json
import os
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
import numpy as np
from tokenizers import pre_tokenizers, processors
@@ -27,6 +27,10 @@ from .english_normalizer import EnglishTextNormalizer
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
+if TYPE_CHECKING:
+ from ...pipelines.conversational import Conversation
+
+
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
@@ -468,7 +472,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._build_conversation_input_ids
- def _build_conversation_input_ids(self, conversation) -> List[int]:
+ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
input_ids = []
for is_user, text in conversation.iter_texts():
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py
index 44c6706afa..d6f9bf9d81 100644
--- a/src/transformers/models/x_clip/modeling_x_clip.py
+++ b/src/transformers/models/x_clip/modeling_x_clip.py
@@ -360,7 +360,7 @@ class XCLIPEncoderLayer(nn.Module):
# Copied from transformers.models.beit.modeling_beit.drop_path
-def drop_path(input, drop_prob: float = 0.0, training: bool = False):
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
diff --git a/src/transformers/models/xglm/modeling_tf_xglm.py b/src/transformers/models/xglm/modeling_tf_xglm.py
index b18c50b795..873df14a69 100644
--- a/src/transformers/models/xglm/modeling_tf_xglm.py
+++ b/src/transformers/models/xglm/modeling_tf_xglm.py
@@ -135,7 +135,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
-def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
diff --git a/src/transformers/models/yolos/image_processing_yolos.py b/src/transformers/models/yolos/image_processing_yolos.py
index e37db77bec..f01dbbf892 100644
--- a/src/transformers/models/yolos/image_processing_yolos.py
+++ b/src/transformers/models/yolos/image_processing_yolos.py
@@ -770,7 +770,7 @@ class YolosImageProcessor(BaseImageProcessor):
return target
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare
- def prepare(self, image, target, return_segmentation_masks=False, masks_path=None):
+ def prepare(self, image, target, return_segmentation_masks=None, masks_path=None):
logger.warning_once(
"The `prepare` method is deprecated and will be removed in a v4.33. "
"Please use `prepare_annotation` instead. Note: the `prepare_annotation` method "
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 65fedf02d8..c27d8c3da9 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -899,6 +899,13 @@ class BartModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
+class BartPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class BartPretrainedModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/tests/repo_utils/test_check_copies.py b/tests/repo_utils/test_check_copies.py
index 57cecf6653..e3e8e47a87 100644
--- a/tests/repo_utils/test_check_copies.py
+++ b/tests/repo_utils/test_check_copies.py
@@ -13,19 +13,19 @@
# limitations under the License.
import os
-import re
import shutil
import sys
import tempfile
import unittest
-
-import black
+from contextlib import contextmanager
+from pathlib import Path
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))
import check_copies # noqa: E402
+from check_copies import convert_to_localized_md, find_code_in_transformers, is_copy_consistent # noqa: E402
# This is the reference code that will be used in the tests.
@@ -49,78 +49,137 @@ REFERENCE_CODE = """ def __init__(self, config):
return hidden_states
"""
+MOCK_BERT_CODE = """from ...modeling_utils import PreTrainedModel
+
+def bert_function(x):
+ return x
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+
+class BertModel(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__()
+ self.bert = BertEncoder(config)
+
+ @add_docstring(BERT_DOCSTRING)
+ def forward(self, x):
+ return self.bert(x)
+"""
+
+MOCK_BERT_COPY_CODE = """from ...modeling_utils import PreTrainedModel
+
+# Copied from transformers.models.bert.modeling_bert.bert_function
+def bert_copy_function(x):
+ return x
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention
+class BertCopyAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+
+# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->BertCopy all-casing
+class BertCopyModel(BertCopyPreTrainedModel):
+ def __init__(self, config):
+ super().__init__()
+ self.bertcopy = BertCopyEncoder(config)
+
+ @add_docstring(BERTCOPY_DOCSTRING)
+ def forward(self, x):
+ return self.bertcopy(x)
+"""
+
+
+def replace_in_file(filename, old, new):
+ with open(filename, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ content = content.replace(old, new)
+
+ with open(filename, "w", encoding="utf-8") as f:
+ f.write(content)
+
+
+def create_tmp_repo(tmp_dir):
+ """
+ Creates a mock repository in a temporary folder for testing.
+ """
+ tmp_dir = Path(tmp_dir)
+ if tmp_dir.exists():
+ shutil.rmtree(tmp_dir)
+ tmp_dir.mkdir(exist_ok=True)
+
+ model_dir = tmp_dir / "src" / "transformers" / "models"
+ model_dir.mkdir(parents=True, exist_ok=True)
+
+ models = {"bert": MOCK_BERT_CODE, "bertcopy": MOCK_BERT_COPY_CODE}
+ for model, code in models.items():
+ model_subdir = model_dir / model
+ model_subdir.mkdir(exist_ok=True)
+ with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8") as f:
+ f.write(code)
+
+
+@contextmanager
+def patch_transformer_repo_path(new_folder):
+ """
+ Temporarily patches the variables defines in `check_copies` to use a different location for the repo.
+ """
+ old_repo_path = check_copies.REPO_PATH
+ old_doc_path = check_copies.PATH_TO_DOCS
+ old_transformer_path = check_copies.TRANSFORMERS_PATH
+ repo_path = Path(new_folder).resolve()
+ check_copies.REPO_PATH = str(repo_path)
+ check_copies.PATH_TO_DOCS = str(repo_path / "docs" / "source" / "en")
+ check_copies.TRANSFORMERS_PATH = str(repo_path / "src" / "transformers")
+ try:
+ yield
+ finally:
+ check_copies.REPO_PATH = old_repo_path
+ check_copies.PATH_TO_DOCS = old_doc_path
+ check_copies.TRANSFORMERS_PATH = old_transformer_path
+
class CopyCheckTester(unittest.TestCase):
- def setUp(self):
- self.transformer_dir = tempfile.mkdtemp()
- os.makedirs(os.path.join(self.transformer_dir, "models/bert/"))
- check_copies.TRANSFORMER_PATH = self.transformer_dir
- shutil.copy(
- os.path.join(git_repo_path, "src/transformers/models/bert/modeling_bert.py"),
- os.path.join(self.transformer_dir, "models/bert/modeling_bert.py"),
- )
-
- def tearDown(self):
- check_copies.TRANSFORMER_PATH = "src/transformers"
- shutil.rmtree(self.transformer_dir)
-
- def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
- code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
- if overwrite_result is not None:
- expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
- mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
- code = black.format_str(code, mode=mode)
- fname = os.path.join(self.transformer_dir, "new_code.py")
- with open(fname, "w", newline="\n") as f:
- f.write(code)
- if overwrite_result is None:
- self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0)
- else:
- check_copies.is_copy_consistent(f.name, overwrite=True)
- with open(fname, "r") as f:
- self.assertTrue(f.read(), expected)
-
def test_find_code_in_transformers(self):
- code = check_copies.find_code_in_transformers("models.bert.modeling_bert.BertLMPredictionHead")
- self.assertEqual(code, REFERENCE_CODE)
+ with tempfile.TemporaryDirectory() as tmp_folder:
+ create_tmp_repo(tmp_folder)
+ with patch_transformer_repo_path(tmp_folder):
+ code = find_code_in_transformers("models.bert.modeling_bert.BertAttention")
+
+ reference_code = (
+ "class BertAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n"
+ )
+ self.assertEqual(code, reference_code)
def test_is_copy_consistent(self):
- # Base copy consistency
- self.check_copy_consistency(
- "# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead",
- "BertLMPredictionHead",
- REFERENCE_CODE + "\n",
- )
+ path_to_check = ["src", "transformers", "models", "bertcopy", "modeling_bertcopy.py"]
+ with tempfile.TemporaryDirectory() as tmp_folder:
+ # Base check
+ create_tmp_repo(tmp_folder)
+ with patch_transformer_repo_path(tmp_folder):
+ file_to_check = os.path.join(tmp_folder, *path_to_check)
+ diffs = is_copy_consistent(file_to_check)
+ self.assertEqual(diffs, [])
- # With no empty line at the end
- self.check_copy_consistency(
- "# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead",
- "BertLMPredictionHead",
- REFERENCE_CODE,
- )
+ # Base check with an inconsistency
+ create_tmp_repo(tmp_folder)
+ with patch_transformer_repo_path(tmp_folder):
+ file_to_check = os.path.join(tmp_folder, *path_to_check)
- # Copy consistency with rename
- self.check_copy_consistency(
- "# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->TestModel",
- "TestModelLMPredictionHead",
- re.sub("Bert", "TestModel", REFERENCE_CODE),
- )
+ replace_in_file(file_to_check, "self.bertcopy(x)", "self.bert(x)")
+ diffs = is_copy_consistent(file_to_check)
+ self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
- # Copy consistency with a really long name
- long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
- self.check_copy_consistency(
- f"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->{long_class_name}",
- f"{long_class_name}LMPredictionHead",
- re.sub("Bert", long_class_name, REFERENCE_CODE),
- )
+ diffs = is_copy_consistent(file_to_check, overwrite=True)
- # Copy consistency with overwrite
- self.check_copy_consistency(
- "# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->TestModel",
- "TestModelLMPredictionHead",
- REFERENCE_CODE,
- overwrite_result=re.sub("Bert", "TestModel", REFERENCE_CODE),
- )
+ with open(file_to_check, "r", encoding="utf-8") as f:
+ self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)
def test_convert_to_localized_md(self):
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
@@ -168,14 +227,14 @@ class CopyCheckTester(unittest.TestCase):
" Christopher D. Manning 发布。\n"
)
- num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
+ num_models_equal, converted_md_list = convert_to_localized_md(
md_list, localized_md_list, localized_readme["format_model_list"]
)
self.assertFalse(num_models_equal)
self.assertEqual(converted_md_list, converted_md_list_sample)
- num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
+ num_models_equal, converted_md_list = convert_to_localized_md(
md_list, converted_md_list, localized_readme["format_model_list"]
)
@@ -201,7 +260,7 @@ class CopyCheckTester(unittest.TestCase):
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
)
- num_models_equal, converted_md_list = check_copies.convert_to_localized_md(
+ num_models_equal, converted_md_list = convert_to_localized_md(
link_changed_md_list, link_unchanged_md_list, localized_readme["format_model_list"]
)
diff --git a/utils/check_copies.py b/utils/check_copies.py
index 959c7b2d32..0352b6419e 100644
--- a/utils/check_copies.py
+++ b/utils/check_copies.py
@@ -12,6 +12,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Utility that checks whether the copies defined in the library match the original or not. This includes:
+- All code commented with `# Copied from` comments,
+- The list of models in the main README.md matches the ones in the localized READMEs and in the index.md,
+- Files that are registered as full copies of one another in the `FULL_COPIES` constant of this script.
+
+This also checks the list of models in the README is complete (has all models) and add a line to complete if there is
+a model missing.
+
+Use from the root of the repo with:
+
+```bash
+python utils/check_copies.py
+```
+
+for a check that will error in case of inconsistencies (used by `make repo-consistency`) or
+
+```bash
+python utils/check_copies.py --fix_and_overwrite
+```
+
+for a check that will fix all inconsistencies automatically (used by `make fix-copies`).
+"""
import argparse
import glob
@@ -103,7 +126,9 @@ transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
def _should_continue(line, indent):
- return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
+ # Helper function. Returns `True` if `line` is empty, starts with the `indent` or is the end parenthesis of a
+ # function definition
+ return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
def find_code_in_transformers(object_name):
@@ -140,7 +165,7 @@ def find_code_in_transformers(object_name):
raise ValueError(f" {object_name} does not match any function or class in {module}.")
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
- start_index = line_index
+ start_index = line_index - 1
while line_index < len(lines) and _should_continue(lines[line_index], indent):
line_index += 1
# Clean up empty lines at the end (if any).
@@ -179,6 +204,33 @@ def blackify(code):
return result[len("class Bla:\n") :] if has_indent else result
+def check_codes_match(observed_code, theoretical_code):
+ """
+ Checks if the code in `observed_code` and `theoretical_code` match with the exception of the class/function name.
+ Returns the index of the first line where there is a difference (if any) and `None` if the codes match.
+ """
+ observed_code_header = observed_code.split("\n")[0]
+ theoretical_code_header = theoretical_code.split("\n")[0]
+
+ _re_class_match = re.compile(r"class\s+([^\(:]+)(?:\(|:)")
+ _re_func_match = re.compile(r"def\s+([^\(]+)\(")
+ for re_pattern in [_re_class_match, _re_func_match]:
+ if re_pattern.match(observed_code_header) is not None:
+ observed_obj_name = re_pattern.search(observed_code_header).groups()[0]
+ theoretical_name = re_pattern.search(theoretical_code_header).groups()[0]
+ theoretical_code_header = theoretical_code_header.replace(theoretical_name, observed_obj_name)
+
+ diff_index = 0
+ if theoretical_code_header != observed_code_header:
+ return 0
+
+ diff_index = 1
+ for observed_line, theoretical_line in zip(observed_code.split("\n")[1:], theoretical_code.split("\n")[1:]):
+ if observed_line != theoretical_line:
+ return diff_index
+ diff_index += 1
+
+
def is_copy_consistent(filename, overwrite=False):
"""
Check if the code commented as a copy in `filename` matches the original.
@@ -201,10 +253,11 @@ def is_copy_consistent(filename, overwrite=False):
theoretical_code = find_code_in_transformers(object_name)
theoretical_indent = get_indent(theoretical_code)
- start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
- indent = theoretical_indent
- line_index = start_index
+ start_index = line_index + 1 if indent == theoretical_indent else line_index
+ line_index = start_index + 1
+ subcode = "\n".join(theoretical_code.split("\n")[1:])
+ indent = get_indent(subcode)
# Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
should_continue = True
while line_index < len(lines) and should_continue:
@@ -212,6 +265,8 @@ def is_copy_consistent(filename, overwrite=False):
if line_index >= len(lines):
break
line = lines[line_index]
+ # There is a special pattern `# End copy` to stop early. It's not documented cause it shouldn't really be
+ # used.
should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
# Clean up empty lines at the end (if any).
while len(lines[line_index - 1]) <= 1:
@@ -233,19 +288,12 @@ def is_copy_consistent(filename, overwrite=False):
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
- # Blackify after replacement. To be able to do that, we need the header (class or function definition)
- # from the previous line
- theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
- theoretical_code = theoretical_code[len(lines[start_index - 1]) :]
+ theoretical_code = blackify(theoretical_code)
# Test for a diff and act accordingly.
- if observed_code != theoretical_code:
- diff_index = start_index + 1
- for observed_line, theoretical_line in zip(observed_code.split("\n"), theoretical_code.split("\n")):
- if observed_line != theoretical_line:
- break
- diff_index += 1
- diffs.append([object_name, diff_index])
+ diff_index = check_codes_match(observed_code, theoretical_code)
+ if diff_index is not None:
+ diffs.append([object_name, diff_index + start_index + 1])
if overwrite:
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
line_index = start_index + 1
@@ -259,6 +307,10 @@ def is_copy_consistent(filename, overwrite=False):
def check_copies(overwrite: bool = False):
+ """
+ Check every file is copy-consistent with the original and maybe `overwrite` content when it is not. Also check the
+ model list in the main README and other READMEs/index.md are consistent.
+ """
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
diffs = []
for filename in all_files:
@@ -275,6 +327,10 @@ def check_copies(overwrite: bool = False):
def check_full_copies(overwrite: bool = False):
+ """
+ Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent and maybe
+ `overwrite` to fix issues.
+ """
diffs = []
for target, source in FULL_COPIES.items():
with open(source, "r", encoding="utf-8") as f:
@@ -299,7 +355,7 @@ def check_full_copies(overwrite: bool = False):
def get_model_list(filename, start_prompt, end_prompt):
- """Extracts the model list from the README."""
+ """Extracts the model list from a README, between `start_prompt` and `end_prompt`."""
with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Find the start of the list.
@@ -327,7 +383,20 @@ def get_model_list(filename, start_prompt, end_prompt):
def convert_to_localized_md(model_list, localized_model_list, format_str):
- """Convert `model_list` to each localized README."""
+ """
+ Compare the model list from the main README to the one in a localized README.
+
+ Args:
+ model_list (`str`): The model list in the main README.
+ localized_model_list (`str`): The model list in one of the localized README.
+ format_str (`str`):
+ The template for a model entry in the localized README (look at the `format_model_list` in the entries of
+ `LOCALIZED_READMES` for examples).
+
+ Returns:
+ `Tuple[bool, str]`: A tuple where the first value indicates if the READMEs match or not, and the second value
+ is the correct localized README.
+ """
def _rep(match):
title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups()
@@ -341,7 +410,8 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
)
# This regex captures metadata from an English model description, including model title, model link,
- # affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for example).
+ # affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for
+ # example).
_re_capture_meta = re.compile(
r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\* \(from ([^)]*)\)[^\[]*([^\)]*\)).*?by (.*?[A-Za-z\*]{2,}?)\. (.*)$"
)
@@ -389,6 +459,10 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
def convert_readme_to_index(model_list):
+ """
+ Converts the model list of the README to the index.md format.
+ """
+ # We need to replce both link to the main doc and stable doc (the order of the next two instructions is important).
model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
return model_list.replace("https://huggingface.co/docs/transformers/", "")
@@ -420,7 +494,9 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
def check_model_list_copy(overwrite=False, max_per_line=119):
- """Check the model lists in the README and index.rst are consistent and maybe `overwrite`."""
+ """
+ Check the model lists in the README is consistent with the ones in the other READMES and also with `index.nmd`.
+ """
# Fix potential doc links in the README
with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
readme = f.read()
@@ -490,6 +566,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
)
+# Map a model name with the name it has in the README for the check_readme check
SPECIAL_MODEL_NAMES = {
"Bert Generation": "BERT For Sequence Generation",
"BigBird": "BigBird-RoBERTa",
@@ -522,7 +599,7 @@ MODELS_NOT_IN_README = [
"VisionTextDualEncoder",
]
-
+# Template for new entries to add in the main README when we have missing models.
README_TEMPLATE = (
"1. **[{model_name}](https://huggingface.co/docs/main/transformers/model_doc/{model_type})** (from "
") released with the paper []() by ."
@@ -530,6 +607,10 @@ README_TEMPLATE = (
def check_readme(overwrite=False):
+ """
+ Check if the main README contains all the models in the library or not. If `overwrite`, will add an entry for the
+ missing models using `README_TEMPLATE`.
+ """
info = LOCALIZED_READMES["README.md"]
models, start_index, end_index, lines = _find_text_in_file(
os.path.join(REPO_PATH, "README.md"),