diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 8a12b2c56c..e34f823500 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -33,7 +33,6 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( - add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, @@ -44,8 +43,6 @@ from .configuration_timesfm import TimesFmConfig logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch" _CONFIG_FOR_DOC = "TimesFmConfig" @@ -734,11 +731,6 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel): @can_return_tuple @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TimesFmOutputForPrediction, - config_class=_CONFIG_FOR_DOC, - ) def forward( self, past_values: Sequence[torch.Tensor], @@ -752,28 +744,40 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel): output_hidden_states: Optional[bool] = None, ) -> TimesFmOutputForPrediction: r""" - window_size (`int`, *optional*): - Window size of trend + residual decomposition. If None then we do not do decomposition. - future_values (`torch.Tensor`, *optional*): - Optional future time series values to be used for loss computation. - forecast_context_len (`int`, *optional*): - Optional max context length. - return_forecast_on_context (`bool`, *optional*): - True to return the forecast on the context when available, i.e. after the first input patch. - truncate_negative (`bool`, *optional*): - Truncate to only non-negative values if any of the contexts have non-negative values, - otherwise do nothing. - output_attentions (`bool`, *optional*): - Whether to output the attentions. - output_hidden_states (`bool`, *optional*): - Whether to output the hidden states. + window_size (`int`, *optional*): + Window size of trend + residual decomposition. If None then we do not do decomposition. + future_values (`torch.Tensor`, *optional*): + Optional future time series values to be used for loss computation. + forecast_context_len (`int`, *optional*): + Optional max context length. + return_forecast_on_context (`bool`, *optional*): + True to return the forecast on the context when available, i.e. after the first input patch. + truncate_negative (`bool`, *optional*): + Truncate to only non-negative values if any of the contexts have non-negative values, + otherwise do nothing. + output_attentions (`bool`, *optional*): + Whether to output the attentions. + output_hidden_states (`bool`, *optional*): + Whether to output the hidden states. Returns: - A TimesFmOutputForPrediction object or a tuple containing: - - the mean forecast of size (# past_values, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# past_values, # forecast horizon, 1 + # quantiles). - - loss: the mean squared error loss + quantile loss if `future_values` is provided. + + Example: + + ```python + >>> from transformers import TimesFmModelForPrediction + + >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch") + + >>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()] + >>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long) + + >>> # Generate + >>> with torch.no_grad(): + >>> outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True) + >>> point_forecast_conv = outputs.mean_predictions + >>> quantile_forecast_conv = outputs.full_predictions + ``` """ if forecast_context_len is None: fcontext_len = self.context_len diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index 4a62752484..e285a4a19e 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -27,7 +27,6 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( - add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, @@ -690,11 +689,6 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel): @can_return_tuple @add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TimesFmOutputForPrediction, - config_class=_CONFIG_FOR_DOC, - ) def forward( self, past_values: Sequence[torch.Tensor], @@ -708,28 +702,40 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel): output_hidden_states: Optional[bool] = None, ) -> TimesFmOutputForPrediction: r""" - window_size (`int`, *optional*): - Window size of trend + residual decomposition. If None then we do not do decomposition. - future_values (`torch.Tensor`, *optional*): - Optional future time series values to be used for loss computation. - forecast_context_len (`int`, *optional*): - Optional max context length. - return_forecast_on_context (`bool`, *optional*): - True to return the forecast on the context when available, i.e. after the first input patch. - truncate_negative (`bool`, *optional*): - Truncate to only non-negative values if any of the contexts have non-negative values, - otherwise do nothing. - output_attentions (`bool`, *optional*): - Whether to output the attentions. - output_hidden_states (`bool`, *optional*): - Whether to output the hidden states. + window_size (`int`, *optional*): + Window size of trend + residual decomposition. If None then we do not do decomposition. + future_values (`torch.Tensor`, *optional*): + Optional future time series values to be used for loss computation. + forecast_context_len (`int`, *optional*): + Optional max context length. + return_forecast_on_context (`bool`, *optional*): + True to return the forecast on the context when available, i.e. after the first input patch. + truncate_negative (`bool`, *optional*): + Truncate to only non-negative values if any of the contexts have non-negative values, + otherwise do nothing. + output_attentions (`bool`, *optional*): + Whether to output the attentions. + output_hidden_states (`bool`, *optional*): + Whether to output the hidden states. Returns: - A TimesFmOutputForPrediction object or a tuple containing: - - the mean forecast of size (# past_values, # forecast horizon), - - the full forecast (mean + quantiles) of size - (# past_values, # forecast horizon, 1 + # quantiles). - - loss: the mean squared error loss + quantile loss if `future_values` is provided. + + Example: + + ```python + >>> from transformers import TimesFmModelForPrediction + + >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch") + + >>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()] + >>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long) + + >>> # Generate + >>> with torch.no_grad(): + >>> outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True) + >>> point_forecast_conv = outputs.mean_predictions + >>> quantile_forecast_conv = outputs.full_predictions + ``` """ if forecast_context_len is None: fcontext_len = self.context_len