From 4f0337a08fb3eaa7582061014b117747802f75a3 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Fri, 14 Oct 2022 15:57:03 +0200 Subject: [PATCH] [Time Series Transformer] Add doc tests (#19607) * Add doc tests * Make it more consistent Co-authored-by: Niels Rogge --- .../modeling_time_series_transformer.py | 97 +++++++++++-------- utils/documentation_tests.txt | 1 + 2 files changed, 55 insertions(+), 43 deletions(-) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index bf39ae1756..df7c844fe2 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -1614,33 +1614,30 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel): Examples: ```python - >>> from transformers import TimeSeriesTransformerModel + >>> from huggingface_hub import hf_hub_download >>> import torch + >>> from transformers import TimeSeriesTransformerModel - >>> model = TimeSeriesTransformerModel.from_pretrained("huggingface/tst-base") + >>> file = hf_hub_download( + ... repo_id="kashif/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) - >>> inputs = dict() - >>> batch_size = 2 - >>> cardinality = 5 - >>> num_time_features = 10 - >>> content_length = 8 - >>> prediction_length = 2 - >>> lags_sequence = [2, 3] - >>> past_length = context_length + max(lags_sequence) + >>> model = TimeSeriesTransformerModel.from_pretrained("huggingface/time-series-transformer-tourism-monthly") - >>> # encoder inputs - >>> inputs["static_categorical_features"] = ids_tensor([batch_size, 1], cardinality) - >>> inputs["static_real_features"] = torch.randn([batch_size, 1]) - >>> inputs["past_time_features"] = torch.randn([batch_size, past_length, num_time_features]) - >>> inputs["past_values"] = torch.randn([batch_size, past_length]) - >>> inputs["past_observed_mask"] = torch.ones([batch_size, past_length]) + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) - >>> # decoder inputs - >>> inputs["future_time_features"] = torch.randn([batch_size, prediction_length, num_time_features]) - >>> inputs["future_values"] = torch.randn([batch_size, prediction_length]) - - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.last_hidden_state + >>> last_hidden_state = outputs.last_hidden_state ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1789,33 +1786,47 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel): Examples: ```python - >>> from transformers import TimeSeriesTransformerForPrediction + >>> from huggingface_hub import hf_hub_download >>> import torch + >>> from transformers import TimeSeriesTransformerForPrediction - >>> model = TimeSeriesTransformerForPrediction.from_pretrained("huggingface/tst-base") + >>> file = hf_hub_download( + ... repo_id="kashif/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) - >>> inputs = dict() - >>> batch_size = 2 - >>> cardinality = 5 - >>> num_time_features = 10 - >>> content_length = 8 - >>> prediction_length = 2 - >>> lags_sequence = [2, 3] - >>> past_length = context_length + max(lags_sequence) + >>> model = TimeSeriesTransformerForPrediction.from_pretrained( + ... "huggingface/time-series-transformer-tourism-monthly" + ... ) - >>> # encoder inputs - >>> inputs["static_categorical_features"] = ids_tensor([batch_size, 1], cardinality) - >>> inputs["static_real_features"] = torch.randn([batch_size, 1]) - >>> inputs["past_time_features"] = torch.randn([batch_size, past_length, num_time_features]) - >>> inputs["past_values"] = torch.randn([batch_size, past_length]) - >>> inputs["past_observed_mask"] = torch.ones([batch_size, past_length]) + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) - >>> # decoder inputs - >>> inputs["future_time_features"] = torch.randn([batch_size, prediction_length, num_time_features]) - >>> inputs["future_values"] = torch.randn([batch_size, prediction_length]) - - >>> outputs = model(**inputs) >>> loss = outputs.loss + >>> loss.backward() + + >>> # during inference, one only provides past values + >>> # as well as possible additional features + >>> # the model autoregressively generates future values + >>> outputs = model.generate( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> mean_prediction = outputs.sequences.mean(dim=1) ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index c53f9ac4f0..cb82c0024a 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -104,6 +104,7 @@ src/transformers/models/segformer/modeling_tf_segformer.py src/transformers/models/swin/configuration_swin.py src/transformers/models/swin/modeling_swin.py src/transformers/models/swinv2/configuration_swinv2.py +src/transformers/models/time_series_transformer/modeling_time_series_transformer.py src/transformers/models/trajectory_transformer/configuration_trajectory_transformer.py src/transformers/models/trocr/modeling_trocr.py src/transformers/models/unispeech/configuration_unispeech.py