Fix and improve CTRL doctests (#16573)
* Improve CTRL doctests * Fix `CTRLForSequenceClassification` flakiness with inconsistent losses * Remove unused * Fixup * Add CTRL to documentation_tests.txt * Fix control code not being first * Add output assertions * Change from sshleifer/tiny-ctrl -> ctrl * Run `make fixup` * apply `list` to output logits shape for clarity * Reduce output loss precision to make assertion more robust * Add assertion of control code being first * Fix docstyle * upper case sentence following control code * Weird bug fixes * Add a better generation example Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -25,15 +25,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer
|
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||||
from .configuration_ctrl import CTRLConfig
|
from .configuration_ctrl import CTRLConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CHECKPOINT_FOR_DOC = "ctrl"
|
|
||||||
_CONFIG_FOR_DOC = "CTRLConfig"
|
_CONFIG_FOR_DOC = "CTRLConfig"
|
||||||
_TOKENIZER_FOR_DOC = "CTRLTokenizer"
|
|
||||||
|
|
||||||
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"ctrl"
|
"ctrl"
|
||||||
@@ -352,12 +350,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
|||||||
self.h[layer].multi_head_attention.prune_heads(heads)
|
self.h[layer].multi_head_attention.prune_heads(heads)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@replace_return_docstrings(output_type=BaseModelOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
processor_class=_TOKENIZER_FOR_DOC,
|
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
|
||||||
output_type=BaseModelOutputWithPast,
|
|
||||||
config_class=_CONFIG_FOR_DOC,
|
|
||||||
)
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
@@ -372,7 +365,28 @@ class CTRLModel(CTRLPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
|
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import CTRLTokenizer, CTRLModel
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> tokenizer = CTRLTokenizer.from_pretrained("ctrl")
|
||||||
|
>>> model = CTRLModel.from_pretrained("ctrl")
|
||||||
|
|
||||||
|
>>> # CTRL was trained with control codes as the first token
|
||||||
|
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
||||||
|
>>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
>>> list(last_hidden_states.shape)
|
||||||
|
[1, 5, 1280]
|
||||||
|
```"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -515,12 +529,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
|
return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
processor_class=_TOKENIZER_FOR_DOC,
|
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
|
||||||
output_type=CausalLMOutputWithPast,
|
|
||||||
config_class=_CONFIG_FOR_DOC,
|
|
||||||
)
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
@@ -541,7 +550,34 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||||
"""
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from transformers import CTRLTokenizer, CTRLLMHeadModel
|
||||||
|
|
||||||
|
>>> tokenizer = CTRLTokenizer.from_pretrained("ctrl")
|
||||||
|
>>> model = CTRLLMHeadModel.from_pretrained("ctrl")
|
||||||
|
|
||||||
|
>>> # CTRL was trained with control codes as the first token
|
||||||
|
>>> inputs = tokenizer("Wikipedia The llama is", return_tensors="pt")
|
||||||
|
>>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
|
||||||
|
|
||||||
|
>>> sequence_ids = model.generate(inputs["input_ids"])
|
||||||
|
>>> sequences = tokenizer.batch_decode(sequence_ids)
|
||||||
|
>>> sequences
|
||||||
|
['Wikipedia The llama is a member of the family Bovidae. It is native to the Andes of Peru,']
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs, labels=inputs["input_ids"])
|
||||||
|
>>> round(outputs.loss.item(), 2)
|
||||||
|
9.21
|
||||||
|
|
||||||
|
>>> list(outputs.logits.shape)
|
||||||
|
[1, 5, 246534]
|
||||||
|
```"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
@@ -619,12 +655,7 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
|
|||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
processor_class=_TOKENIZER_FOR_DOC,
|
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
|
||||||
output_type=SequenceClassifierOutput,
|
|
||||||
config_class=_CONFIG_FOR_DOC,
|
|
||||||
)
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
@@ -645,7 +676,77 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
|
|||||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||||
"""
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example of single-label classification:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from transformers import CTRLTokenizer, CTRLForSequenceClassification
|
||||||
|
|
||||||
|
>>> tokenizer = CTRLTokenizer.from_pretrained("ctrl")
|
||||||
|
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl")
|
||||||
|
|
||||||
|
>>> # CTRL was trained with control codes as the first token
|
||||||
|
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
||||||
|
>>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
|
||||||
|
|
||||||
|
>>> with torch.no_grad():
|
||||||
|
... logits = model(**inputs).logits
|
||||||
|
|
||||||
|
>>> predicted_class_id = logits.argmax().item()
|
||||||
|
>>> model.config.id2label[predicted_class_id]
|
||||||
|
'LABEL_0'
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> torch.manual_seed(42) # doctest: +IGNORE_RESULT
|
||||||
|
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
||||||
|
>>> num_labels = len(model.config.id2label)
|
||||||
|
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl", num_labels=num_labels)
|
||||||
|
|
||||||
|
>>> labels = torch.tensor(1)
|
||||||
|
>>> loss = model(**inputs, labels=labels).loss
|
||||||
|
>>> round(loss.item(), 2)
|
||||||
|
0.35
|
||||||
|
```
|
||||||
|
|
||||||
|
Example of multi-label classification:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from transformers import CTRLTokenizer, CTRLForSequenceClassification
|
||||||
|
|
||||||
|
>>> tokenizer = CTRLTokenizer.from_pretrained("ctrl")
|
||||||
|
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl", problem_type="multi_label_classification")
|
||||||
|
|
||||||
|
>>> # CTRL was trained with control codes as the first token
|
||||||
|
>>> inputs = tokenizer("Opinion My dog is cute", return_tensors="pt")
|
||||||
|
>>> assert inputs["input_ids"][0, 0].item() in tokenizer.control_codes.values()
|
||||||
|
|
||||||
|
>>> with torch.no_grad():
|
||||||
|
... logits = model(**inputs).logits
|
||||||
|
|
||||||
|
>>> predicted_class_id = logits.argmax().item()
|
||||||
|
>>> model.config.id2label[predicted_class_id]
|
||||||
|
'LABEL_0'
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
||||||
|
>>> num_labels = len(model.config.id2label)
|
||||||
|
>>> model = CTRLForSequenceClassification.from_pretrained("ctrl", num_labels=num_labels)
|
||||||
|
|
||||||
|
>>> num_labels = len(model.config.id2label)
|
||||||
|
>>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
|
||||||
|
... torch.float
|
||||||
|
... )
|
||||||
|
>>> loss = model(**inputs, labels=labels).loss
|
||||||
|
>>> loss.backward() # doctest: +IGNORE_RESULT
|
||||||
|
```"""
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
|||||||
@@ -52,3 +52,4 @@ src/transformers/models/wav2vec2/modeling_wav2vec2.py
|
|||||||
src/transformers/models/wav2vec2/tokenization_wav2vec2.py
|
src/transformers/models/wav2vec2/tokenization_wav2vec2.py
|
||||||
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
|
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
|
||||||
src/transformers/models/wavlm/modeling_wavlm.py
|
src/transformers/models/wavlm/modeling_wavlm.py
|
||||||
|
src/transformers/models/ctrl/modeling_ctrl.py
|
||||||
|
|||||||
Reference in New Issue
Block a user