[Time Series Transformer] Add doc tests (#19607)
* Add doc tests * Make it more consistent Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -1614,33 +1614,30 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import TimeSeriesTransformerModel
|
||||
>>> from huggingface_hub import hf_hub_download
|
||||
>>> import torch
|
||||
>>> from transformers import TimeSeriesTransformerModel
|
||||
|
||||
>>> model = TimeSeriesTransformerModel.from_pretrained("huggingface/tst-base")
|
||||
>>> file = hf_hub_download(
|
||||
... repo_id="kashif/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
|
||||
... )
|
||||
>>> batch = torch.load(file)
|
||||
|
||||
>>> inputs = dict()
|
||||
>>> batch_size = 2
|
||||
>>> cardinality = 5
|
||||
>>> num_time_features = 10
|
||||
>>> content_length = 8
|
||||
>>> prediction_length = 2
|
||||
>>> lags_sequence = [2, 3]
|
||||
>>> past_length = context_length + max(lags_sequence)
|
||||
>>> model = TimeSeriesTransformerModel.from_pretrained("huggingface/time-series-transformer-tourism-monthly")
|
||||
|
||||
>>> # encoder inputs
|
||||
>>> inputs["static_categorical_features"] = ids_tensor([batch_size, 1], cardinality)
|
||||
>>> inputs["static_real_features"] = torch.randn([batch_size, 1])
|
||||
>>> inputs["past_time_features"] = torch.randn([batch_size, past_length, num_time_features])
|
||||
>>> inputs["past_values"] = torch.randn([batch_size, past_length])
|
||||
>>> inputs["past_observed_mask"] = torch.ones([batch_size, past_length])
|
||||
>>> # during training, one provides both past and future values
|
||||
>>> # as well as possible additional features
|
||||
>>> outputs = model(
|
||||
... past_values=batch["past_values"],
|
||||
... past_time_features=batch["past_time_features"],
|
||||
... past_observed_mask=batch["past_observed_mask"],
|
||||
... static_categorical_features=batch["static_categorical_features"],
|
||||
... static_real_features=batch["static_real_features"],
|
||||
... future_values=batch["future_values"],
|
||||
... future_time_features=batch["future_time_features"],
|
||||
... )
|
||||
|
||||
>>> # decoder inputs
|
||||
>>> inputs["future_time_features"] = torch.randn([batch_size, prediction_length, num_time_features])
|
||||
>>> inputs["future_values"] = torch.randn([batch_size, prediction_length])
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -1789,33 +1786,47 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import TimeSeriesTransformerForPrediction
|
||||
>>> from huggingface_hub import hf_hub_download
|
||||
>>> import torch
|
||||
>>> from transformers import TimeSeriesTransformerForPrediction
|
||||
|
||||
>>> model = TimeSeriesTransformerForPrediction.from_pretrained("huggingface/tst-base")
|
||||
>>> file = hf_hub_download(
|
||||
... repo_id="kashif/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
|
||||
... )
|
||||
>>> batch = torch.load(file)
|
||||
|
||||
>>> inputs = dict()
|
||||
>>> batch_size = 2
|
||||
>>> cardinality = 5
|
||||
>>> num_time_features = 10
|
||||
>>> content_length = 8
|
||||
>>> prediction_length = 2
|
||||
>>> lags_sequence = [2, 3]
|
||||
>>> past_length = context_length + max(lags_sequence)
|
||||
>>> model = TimeSeriesTransformerForPrediction.from_pretrained(
|
||||
... "huggingface/time-series-transformer-tourism-monthly"
|
||||
... )
|
||||
|
||||
>>> # encoder inputs
|
||||
>>> inputs["static_categorical_features"] = ids_tensor([batch_size, 1], cardinality)
|
||||
>>> inputs["static_real_features"] = torch.randn([batch_size, 1])
|
||||
>>> inputs["past_time_features"] = torch.randn([batch_size, past_length, num_time_features])
|
||||
>>> inputs["past_values"] = torch.randn([batch_size, past_length])
|
||||
>>> inputs["past_observed_mask"] = torch.ones([batch_size, past_length])
|
||||
>>> # during training, one provides both past and future values
|
||||
>>> # as well as possible additional features
|
||||
>>> outputs = model(
|
||||
... past_values=batch["past_values"],
|
||||
... past_time_features=batch["past_time_features"],
|
||||
... past_observed_mask=batch["past_observed_mask"],
|
||||
... static_categorical_features=batch["static_categorical_features"],
|
||||
... static_real_features=batch["static_real_features"],
|
||||
... future_values=batch["future_values"],
|
||||
... future_time_features=batch["future_time_features"],
|
||||
... )
|
||||
|
||||
>>> # decoder inputs
|
||||
>>> inputs["future_time_features"] = torch.randn([batch_size, prediction_length, num_time_features])
|
||||
>>> inputs["future_values"] = torch.randn([batch_size, prediction_length])
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> loss = outputs.loss
|
||||
>>> loss.backward()
|
||||
|
||||
>>> # during inference, one only provides past values
|
||||
>>> # as well as possible additional features
|
||||
>>> # the model autoregressively generates future values
|
||||
>>> outputs = model.generate(
|
||||
... past_values=batch["past_values"],
|
||||
... past_time_features=batch["past_time_features"],
|
||||
... past_observed_mask=batch["past_observed_mask"],
|
||||
... static_categorical_features=batch["static_categorical_features"],
|
||||
... static_real_features=batch["static_real_features"],
|
||||
... future_time_features=batch["future_time_features"],
|
||||
... )
|
||||
|
||||
>>> mean_prediction = outputs.sequences.mean(dim=1)
|
||||
```"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@@ -104,6 +104,7 @@ src/transformers/models/segformer/modeling_tf_segformer.py
|
||||
src/transformers/models/swin/configuration_swin.py
|
||||
src/transformers/models/swin/modeling_swin.py
|
||||
src/transformers/models/swinv2/configuration_swinv2.py
|
||||
src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
|
||||
src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py
|
||||
src/transformers/models/trocr/modeling_trocr.py
|
||||
src/transformers/models/unispeech/configuration_unispeech.py
|
||||
|
||||
Reference in New Issue
Block a user