[Doctest] added doctest changes for electra (#16675)
* added doctest changes for electra
* fixed doctest tests
* updated changes
This commit is contained in:
@@ -967,9 +967,11 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
|
|||||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_TOKENIZER_FOR_DOC,
|
processor_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint="bhadresh-savani/electra-base-emotion",
|
||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output="'joy'",
|
||||||
|
expected_loss=0.06,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1087,16 +1089,25 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
|
|||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
>>> from transformers import ElectraTokenizer, ElectraForPreTraining
|
>>> from transformers import ElectraForPreTraining, ElectraTokenizerFast
|
||||||
>>> import torch
|
>>> import torch
|
||||||
|
|
||||||
>>> tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
|
>>> discriminator = ElectraForPreTraining.from_pretrained("google/electra-base-discriminator")
|
||||||
>>> model = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
|
>>> tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-base-discriminator")
|
||||||
|
|
||||||
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
|
>>> sentence = "The quick brown fox jumps over the lazy dog"
|
||||||
... 0
|
>>> fake_sentence = "The quick brown fox fake over the lazy dog"
|
||||||
>>> ) # Batch size 1
|
|
||||||
>>> logits = model(input_ids).logits
|
>>> fake_tokens = tokenizer.tokenize(fake_sentence, add_special_tokens=True)
|
||||||
|
>>> fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt")
|
||||||
|
>>> discriminator_outputs = discriminator(fake_inputs)
|
||||||
|
>>> predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2)
|
||||||
|
|
||||||
|
>>> fake_tokens
|
||||||
|
['[CLS]', 'the', 'quick', 'brown', 'fox', 'fake', 'over', 'the', 'lazy', 'dog', '[SEP]']
|
||||||
|
|
||||||
|
>>> predictions.squeeze().tolist()
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
```"""
|
```"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
@@ -1167,9 +1178,12 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
|
|||||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_TOKENIZER_FOR_DOC,
|
processor_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint="google/electra-small-generator",
|
||||||
output_type=MaskedLMOutput,
|
output_type=MaskedLMOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
mask="[MASK]",
|
||||||
|
expected_output="'paris'",
|
||||||
|
expected_loss=1.22,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1251,9 +1265,11 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
|
|||||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_TOKENIZER_FOR_DOC,
|
processor_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint="bhadresh-savani/electra-base-discriminator-finetuned-conll03-english",
|
||||||
output_type=TokenClassifierOutput,
|
output_type=TokenClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output="['B-LOC', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'I-LOC']",
|
||||||
|
expected_loss=0.11,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1331,9 +1347,13 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
|
|||||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_TOKENIZER_FOR_DOC,
|
processor_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint="bhadresh-savani/electra-base-squad2",
|
||||||
output_type=QuestionAnsweringModelOutput,
|
output_type=QuestionAnsweringModelOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
qa_target_start_index=11,
|
||||||
|
qa_target_end_index=12,
|
||||||
|
expected_output="'a nice puppet'",
|
||||||
|
expected_loss=2.64,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1160,9 +1160,12 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
|
|||||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_TOKENIZER_FOR_DOC,
|
processor_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint="google/electra-small-generator",
|
||||||
output_type=TFMaskedLMOutput,
|
output_type=TFMaskedLMOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
mask="[MASK]",
|
||||||
|
expected_output="'paris'",
|
||||||
|
expected_loss=1.22,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
@@ -1269,9 +1272,11 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
|
|||||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_TOKENIZER_FOR_DOC,
|
processor_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint="bhadresh-savani/electra-base-emotion",
|
||||||
output_type=TFSequenceClassifierOutput,
|
output_type=TFSequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output="'joy'",
|
||||||
|
expected_loss=0.06,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
@@ -1478,9 +1483,11 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
|
|||||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_TOKENIZER_FOR_DOC,
|
processor_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint="bhadresh-savani/electra-base-discriminator-finetuned-conll03-english",
|
||||||
output_type=TFTokenClassifierOutput,
|
output_type=TFTokenClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output="['B-LOC', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'I-LOC']",
|
||||||
|
expected_loss=0.11,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
@@ -1558,9 +1565,13 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
|
|||||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_TOKENIZER_FOR_DOC,
|
processor_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
checkpoint="bhadresh-savani/electra-base-squad2",
|
||||||
output_type=TFQuestionAnsweringModelOutput,
|
output_type=TFQuestionAnsweringModelOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
qa_target_start_index=11,
|
||||||
|
qa_target_end_index=12,
|
||||||
|
expected_output="'a nice puppet'",
|
||||||
|
expected_loss=2.64,
|
||||||
)
|
)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ src/transformers/models/convnext/modeling_convnext.py
|
|||||||
src/transformers/models/data2vec/modeling_data2vec_audio.py
|
src/transformers/models/data2vec/modeling_data2vec_audio.py
|
||||||
src/transformers/models/deit/modeling_deit.py
|
src/transformers/models/deit/modeling_deit.py
|
||||||
src/transformers/models/dpt/modeling_dpt.py
|
src/transformers/models/dpt/modeling_dpt.py
|
||||||
|
src/transformers/models/electra/modeling_electra.py
|
||||||
|
src/transformers/models/electra/modeling_tf_electra.py
|
||||||
src/transformers/models/glpn/modeling_glpn.py
|
src/transformers/models/glpn/modeling_glpn.py
|
||||||
src/transformers/models/gpt2/modeling_gpt2.py
|
src/transformers/models/gpt2/modeling_gpt2.py
|
||||||
src/transformers/models/gptj/modeling_gptj.py
|
src/transformers/models/gptj/modeling_gptj.py
|
||||||
|
|||||||
Reference in New Issue
Block a user