@@ -33,7 +33,6 @@ from ...modeling_outputs import BaseModelOutput
|
|||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
from ...processing_utils import Unpack
|
from ...processing_utils import Unpack
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_code_sample_docstrings,
|
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
can_return_tuple,
|
can_return_tuple,
|
||||||
@@ -44,8 +43,6 @@ from .configuration_timesfm import TimesFmConfig
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch"
|
|
||||||
_CONFIG_FOR_DOC = "TimesFmConfig"
|
_CONFIG_FOR_DOC = "TimesFmConfig"
|
||||||
|
|
||||||
|
|
||||||
@@ -734,11 +731,6 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
|
|||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
|
||||||
@add_code_sample_docstrings(
|
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
|
||||||
output_type=TimesFmOutputForPrediction,
|
|
||||||
config_class=_CONFIG_FOR_DOC,
|
|
||||||
)
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
past_values: Sequence[torch.Tensor],
|
past_values: Sequence[torch.Tensor],
|
||||||
@@ -769,11 +761,23 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
|
|||||||
Whether to output the hidden states.
|
Whether to output the hidden states.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A TimesFmOutputForPrediction object or a tuple containing:
|
|
||||||
- the mean forecast of size (# past_values, # forecast horizon),
|
Example:
|
||||||
- the full forecast (mean + quantiles) of size
|
|
||||||
(# past_values, # forecast horizon, 1 + # quantiles).
|
```python
|
||||||
- loss: the mean squared error loss + quantile loss if `future_values` is provided.
|
>>> from transformers import TimesFmModelForPrediction
|
||||||
|
|
||||||
|
>>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
|
||||||
|
|
||||||
|
>>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
|
||||||
|
>>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> with torch.no_grad():
|
||||||
|
>>> outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
|
||||||
|
>>> point_forecast_conv = outputs.mean_predictions
|
||||||
|
>>> quantile_forecast_conv = outputs.full_predictions
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
if forecast_context_len is None:
|
if forecast_context_len is None:
|
||||||
fcontext_len = self.context_len
|
fcontext_len = self.context_len
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from ...modeling_outputs import BaseModelOutput
|
|||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
from ...processing_utils import Unpack
|
from ...processing_utils import Unpack
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_code_sample_docstrings,
|
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
can_return_tuple,
|
can_return_tuple,
|
||||||
@@ -690,11 +689,6 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
|
|||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
|
||||||
@add_code_sample_docstrings(
|
|
||||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
|
||||||
output_type=TimesFmOutputForPrediction,
|
|
||||||
config_class=_CONFIG_FOR_DOC,
|
|
||||||
)
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
past_values: Sequence[torch.Tensor],
|
past_values: Sequence[torch.Tensor],
|
||||||
@@ -725,11 +719,23 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
|
|||||||
Whether to output the hidden states.
|
Whether to output the hidden states.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A TimesFmOutputForPrediction object or a tuple containing:
|
|
||||||
- the mean forecast of size (# past_values, # forecast horizon),
|
Example:
|
||||||
- the full forecast (mean + quantiles) of size
|
|
||||||
(# past_values, # forecast horizon, 1 + # quantiles).
|
```python
|
||||||
- loss: the mean squared error loss + quantile loss if `future_values` is provided.
|
>>> from transformers import TimesFmModelForPrediction
|
||||||
|
|
||||||
|
>>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
|
||||||
|
|
||||||
|
>>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
|
||||||
|
>>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> with torch.no_grad():
|
||||||
|
>>> outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
|
||||||
|
>>> point_forecast_conv = outputs.mean_predictions
|
||||||
|
>>> quantile_forecast_conv = outputs.full_predictions
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
if forecast_context_len is None:
|
if forecast_context_len is None:
|
||||||
fcontext_len = self.context_len
|
fcontext_len = self.context_len
|
||||||
|
|||||||
Reference in New Issue
Block a user