PatchtTST and PatchTSMixer fixes (#28083)

* 🐛 fix .max bug

* remove prediction_length from regression output dimensions

* fix parameter names, fix output names, update tests

* ensure shape for PatchTST

* ensure output shape for PatchTSMixer

* update model, batch, and expected for regression distribution test

* update test expected

Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* Update tests/models/patchtst/test_modeling_patchtst.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtst/test_modeling_patchtst.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtst/test_modeling_patchtst.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/patchtsmixer/modeling_patchtsmixer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* standardize on patch_length

Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Make arguments more explicit

Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* adjust prepared inputs

Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

---------

Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>
Co-authored-by: Wesley M. Gifford <wmgifford@us.ibm.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Wesley Gifford
2024-01-29 05:09:26 -05:00
committed by GitHub
parent 3a08cc485f
commit f72c7c22d9
5 changed files with 115 additions and 89 deletions

View File

@@ -40,7 +40,7 @@ class PatchTSMixerConfig(PretrainedConfig):
Args:
context_length (`int`, *optional*, defaults to 32):
The context/history length for the input sequence.
patch_len (`int`, *optional*, defaults to 8):
patch_length (`int`, *optional*, defaults to 8):
The patch length for the input sequence.
num_input_channels (`int`, *optional*, defaults to 1):
Number of input variates. For Univariate, set it to 1.
@@ -51,7 +51,7 @@ class PatchTSMixerConfig(PretrainedConfig):
The number of samples to generate in parallel for probabilistic forecast.
d_model (`int`, *optional*, defaults to 8):
Hidden dimension of the model. Recommended to set it as a multiple of patch_length (i.e. 2-5X of
patch_len). Larger value indicates more complex model.
patch_length). Larger value indicates more complex model.
expansion_factor (`int`, *optional*, defaults to 2):
Expansion factor to use inside MLP. Recommended range is 2-5. Larger value indicates more complex model.
num_layers (`int`, *optional*, defaults to 3):
@@ -155,7 +155,7 @@ class PatchTSMixerConfig(PretrainedConfig):
self,
# Time series specific configuration
context_length: int = 32,
patch_len: int = 8,
patch_length: int = 8,
num_input_channels: int = 1,
patch_stride: int = 8,
num_parallel_samples: int = 100,
@@ -198,7 +198,7 @@ class PatchTSMixerConfig(PretrainedConfig):
):
self.num_input_channels = num_input_channels
self.context_length = context_length
self.patch_length = patch_len
self.patch_length = patch_length
self.patch_stride = patch_stride
self.d_model = d_model
self.expansion_factor = expansion_factor
@@ -209,7 +209,7 @@ class PatchTSMixerConfig(PretrainedConfig):
self.norm_mlp = norm_mlp
self.scaling = scaling
self.head_dropout = head_dropout
self.num_patches = (max(context_length, patch_len) - patch_len) // patch_stride + 1
self.num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
self.mask_type = mask_type
self.random_mask_ratio = random_mask_ratio
self.num_forecast_mask_patches = num_forecast_mask_patches

View File

@@ -888,7 +888,7 @@ def forecast_masking(
Parameters:
inputs (`torch.Tensor`):
Input of shape `(bs, num_channels, num_patch, patch_len)`
Input of shape `(bs, num_channels, num_patch, patch_length)`
num_forecast_mask_patches (`list`):
Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
unmasked_channel_indices (`list`, *optional*):
@@ -1864,15 +1864,15 @@ class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
def forward(
self,
past_values: torch.Tensor,
future_values: torch.Tensor = None,
target_values: torch.Tensor = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
) -> PatchTSMixerForTimeSeriesClassificationOutput:
r"""
future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
`(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
values of the time series, that serve as labels for the model. The `future_values` is what the
values of the time series, that serve as labels for the model. The `target_values` is what the
Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
required for a pretraining task.
@@ -1912,8 +1912,8 @@ class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
y_hat = self.head(model_output.last_hidden_state) # tensor [batch_size x n_labels]
if future_values is not None and return_loss is True:
loss_val = loss(y_hat, future_values)
if target_values is not None and return_loss is True:
loss_val = loss(y_hat, target_values)
else:
loss_val = None
@@ -1942,7 +1942,7 @@ class PatchTSMixerForRegressionOutput(ModelOutput):
Output type of [`PatchTSMixerForRegressionOutput`].
Args:
prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
Prediction output from the regression head.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
Backbone embeddings before passing through the head.
@@ -1953,7 +1953,7 @@ class PatchTSMixerForRegressionOutput(ModelOutput):
"""
loss: Optional[torch.FloatTensor] = None
prediction_outputs: torch.FloatTensor = None
regression_outputs: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@@ -2054,15 +2054,15 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
def forward(
self,
past_values: torch.Tensor,
future_values: torch.Tensor = None,
target_values: torch.Tensor = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
) -> PatchTSMixerForRegressionOutput:
r"""
future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
`(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
values of the time series, that serve as labels for the model. The `future_values` is what the
values of the time series, that serve as labels for the model. The `target_values` is what the
Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
required for a pretraining task.
@@ -2106,16 +2106,18 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
y_hat = self.head(model_output.last_hidden_state) # [batch_size x num_targets]
if future_values is not None and return_loss is True:
if target_values is not None and return_loss is True:
if self.distribution_output:
if self.distribution_output == "negative_binomial" and torch.any(future_values < 0):
raise Exception("future_values cannot be negative for negative_binomial distribution.")
if self.distribution_output == "negative_binomial" and torch.any(target_values < 0):
raise Exception("target_values cannot be negative for negative_binomial distribution.")
distribution = self.distribution_output.distribution(y_hat)
loss_val = loss(distribution, future_values)
# y_hat should be a 2-tuple, each with dimension [bs, num_targets]
y_hat = tuple([item.view(-1, self.config.num_targets) for item in y_hat])
loss_val = loss(distribution, target_values)
# take average of the loss
loss_val = weighted_average(loss_val)
else:
loss_val = loss(y_hat, future_values)
loss_val = loss(y_hat, target_values)
else:
loss_val = None
@@ -2132,7 +2134,7 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
return PatchTSMixerForRegressionOutput(
loss=loss_val,
prediction_outputs=y_hat, # tensor [batch_size x num_targets]
regression_outputs=y_hat, # tensor [batch_size x num_targets]
last_hidden_state=model_output.last_hidden_state, # [batch_size x nvars x num_patch x d_model]
hidden_states=model_output.hidden_states,
)
@@ -2146,7 +2148,7 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
Args:
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Past values of the time series that serves as context in order to predict the future.
Past values of the time series that serves as context in order to predict the target values.
Return:
[`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
@@ -2158,17 +2160,18 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
# get model output
outputs = self(
past_values=past_values,
future_values=None,
target_values=None,
output_hidden_states=False,
)
# get distribution
distribution = self.distribution_output.distribution(outputs.prediction_outputs)
distribution = self.distribution_output.distribution(outputs.regression_outputs)
# get samples
samples = [
distribution.sample() for _ in range(num_parallel_samples)
] # samples: list of [batch_size x num_targets]
# stack tensors
samples = torch.stack(samples, dim=1) # [batch_size x num_samples x num_targets]
# [batch_size x num_samples x num_targets]
samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
return SamplePatchTSMixerRegressionOutput(sequences=samples)

View File

@@ -289,7 +289,7 @@ def forecast_masking(
Parameters:
inputs (`torch.Tensor`):
Input of shape `(bs, num_channels, num_patch, patch_len)`
Input of shape `(bs, num_channels, num_patch, patch_length)`
num_forecast_mask_patches (`list`):
Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
unmasked_channel_indices (`list`, *optional*):
@@ -1430,7 +1430,7 @@ class PatchTSTClassificationHead(nn.Module):
pooled_embedding = embedding.mean(dim=2)
elif self.pooling_type == "max":
# pooled_embedding: [bs x num_channels x d_model]
pooled_embedding = embedding.max(dim=2)
pooled_embedding = embedding.max(dim=2).values
else:
raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
# pooled_embedding: bs x num_channels * d_model
@@ -1602,7 +1602,7 @@ class PatchTSTPredictionHead(nn.Module):
pooled_embedding = embedding.mean(dim=2)
elif self.pooling_type == "max":
# pooled_embedding: [bs x num_channels x d_model]
pooled_embedding = embedding.max(dim=2)
pooled_embedding = embedding.max(dim=2).values
else:
# pooled_embedding: [bs x num_channels x num_patches x d_model]
pooled_embedding = embedding
@@ -1866,7 +1866,7 @@ class PatchTSTRegressionHead(nn.Module):
pooled_embedding = embedding.mean(dim=2)
elif self.pooling_type == "max":
# pooled_embedding: [bs x num_channels x d_model]
pooled_embedding = embedding.max(dim=2)
pooled_embedding = embedding.max(dim=2).values
else:
raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
# flatten the input
@@ -1899,11 +1899,11 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
self.distribution_output = None
else:
if config.distribution_output == "student_t":
self.distribution_output = StudentTOutput(dim=config.prediction_length * config.num_targets)
self.distribution_output = StudentTOutput(dim=config.num_targets)
elif config.distribution_output == "normal":
self.distribution_output = NormalOutput(dim=config.prediction_length * config.num_targets)
self.distribution_output = NormalOutput(dim=config.num_targets)
elif config.distribution_output == "negative_binomial":
self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length * config.num_targets)
self.distribution_output = NegativeBinomialOutput(dim=config.num_targets)
else:
raise ValueError(f"Unknown distribution output {config.distribution_output}")
@@ -1974,6 +1974,8 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
if target_values is not None:
if self.distribution_output:
distribution = self.distribution_output.distribution(y_hat)
# y_hat should be a 2-tuple, each with dimension [bs, num_targets]
y_hat = tuple([item.view(-1, self.config.num_targets) for item in y_hat])
loss = nll(distribution, target_values)
# take average of the loss
loss = weighted_average(loss)
@@ -1982,6 +1984,7 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
loss = loss(y_hat, target_values)
if not return_dict:
# hidden_states, attentions, mask
outputs = (y_hat,) + model_output[1:-3]
outputs = (loss,) + outputs if loss is not None else outputs
return outputs
@@ -2030,5 +2033,5 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
# get samples: list of [bs x num_targets]
samples = [distribution.sample() for _ in range(num_parallel_samples)]
# samples: [bs x num_samples x num_targets]
samples = torch.stack(samples, dim=1)
samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
return SamplePatchTSTOutput(sequences=samples)

View File

@@ -191,11 +191,8 @@ class PatchTSMixerModelTester:
# [bs x context_length x n_vars]
past_values = floats_tensor([self.batch_size, _past_length, self.num_input_channels])
future_values = floats_tensor([self.batch_size, config.prediction_length, self.num_input_channels])
inputs_dict = {
"past_values": past_values,
"future_values": future_values,
}
return inputs_dict
@@ -256,21 +253,25 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
# if classification model:
if model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
if model_class == PatchTSMixerForPrediction:
rng = random.Random(self.model_tester.seed_number)
labels = floats_tensor(
[
self.model_tester.batch_size,
self.model_tester.prediction_length,
self.model_tester.num_input_channels,
],
rng=rng,
)
inputs_dict["future_values"] = labels
elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
rng = random.Random(self.model_tester.seed_number)
labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
# inputs_dict["labels"] = labels
inputs_dict["future_values"] = labels
# inputs_dict.pop("future_values")
inputs_dict["target_values"] = labels
elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
rng = random.Random(self.model_tester.seed_number)
labels = floats_tensor([self.model_tester.batch_size, self.model_tester.num_targets], rng=rng)
# inputs_dict["labels"] = labels
inputs_dict["future_values"] = labels
# inputs_dict.pop("future_values")
elif model_class in [PatchTSMixerModel, PatchTSMixerForPretraining]:
inputs_dict.pop("future_values")
inputs_dict["target_values"] = labels
inputs_dict["output_hidden_states"] = True
return inputs_dict
@@ -409,28 +410,37 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names_with_target = [
if model_class == PatchTSMixerForPretraining:
expected_arg_names = [
"past_values",
"observed_mask",
"output_hidden_states",
"return_loss",
]
elif model_class == PatchTSMixerModel:
expected_arg_names = [
"past_values",
"observed_mask",
"output_hidden_states",
]
elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values(
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING
):
expected_arg_names = [
"past_values",
"target_values",
"output_hidden_states",
"return_loss",
]
else:
# PatchTSMixerForPrediction
expected_arg_names = [
"past_values",
"observed_mask",
"future_values",
"output_hidden_states",
"return_loss",
]
expected_arg_names_without_target = [
"past_values",
"observed_mask",
"output_hidden_states",
]
expected_arg_names = expected_arg_names_with_target
if model_class == PatchTSMixerForPretraining:
expected_arg_names = expected_arg_names_without_target + ["return_loss"]
if model_class == PatchTSMixerModel:
expected_arg_names = expected_arg_names_without_target
if model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values(
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING
):
expected_arg_names.remove("observed_mask")
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@@ -686,20 +696,27 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
else:
target_output = target_input
ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1, -1)
ground_truth_arg = "future_values"
output_predictions_arg = "prediction_outputs"
elif task == "classification":
mdl = PatchTSMixerForTimeSeriesClassification(config)
target_input = self.__class__.correct_classification_classes
target_output = self.__class__.correct_classification_output
ground_truth_arg = "target_values"
output_predictions_arg = "prediction_outputs"
elif task == "regression":
mdl = PatchTSMixerForRegression(config)
target_input = self.__class__.correct_regression_output
target_output = self.__class__.correct_regression_output
ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1)
ground_truth_arg = "target_values"
output_predictions_arg = "regression_outputs"
elif task == "pretrain":
mdl = PatchTSMixerForPretraining(config)
target_input = None
target_output = self.__class__.correct_pretrain_output
ground_truth_arg = None
output_predictions_arg = "prediction_outputs"
else:
print("invalid task")
@@ -710,15 +727,18 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
else:
output = mdl(
self.__class__.data,
future_values=target_input,
output_hidden_states=output_hidden_states,
**{
ground_truth_arg: target_input,
"output_hidden_states": output_hidden_states,
},
)
if isinstance(output.prediction_outputs, tuple):
for t in output.prediction_outputs:
prediction_outputs = getattr(output, output_predictions_arg)
if isinstance(prediction_outputs, tuple):
for t in prediction_outputs:
self.assertEqual(t.shape, target_output.shape)
else:
self.assertEqual(output.prediction_outputs.shape, target_output.shape)
self.assertEqual(prediction_outputs.shape, target_output.shape)
self.assertEqual(output.last_hidden_state.shape, enc_output.shape)
@@ -980,7 +1000,7 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
mdl = PatchTSMixerForTimeSeriesClassification(config)
output = mdl(
self.__class__.data,
future_values=self.__class__.correct_classification_classes,
target_values=self.__class__.correct_classification_classes,
)
self.assertEqual(
output.prediction_outputs.shape,
@@ -994,7 +1014,7 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
mdl = PatchTSMixerForTimeSeriesClassification(config)
output = mdl(
self.__class__.data,
future_values=self.__class__.correct_classification_classes,
target_values=self.__class__.correct_classification_classes,
return_dict=False,
)
if isinstance(output, tuple):
@@ -1017,9 +1037,9 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
def test_regression_full(self):
config = PatchTSMixerConfig(**self.__class__.params)
mdl = PatchTSMixerForRegression(config)
output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
self.assertEqual(
output.prediction_outputs.shape,
output.regression_outputs.shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
@@ -1030,13 +1050,13 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
mdl = PatchTSMixerForRegression(config)
output = mdl(
self.__class__.data,
future_values=self.__class__.correct_regression_output,
target_values=self.__class__.correct_regression_output,
return_dict=False,
)
if isinstance(output, tuple):
output = PatchTSMixerForRegressionOutput(*output)
self.assertEqual(
output.prediction_outputs.shape,
output.regression_outputs.shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
@@ -1049,13 +1069,13 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
config = PatchTSMixerConfig(**params)
mdl = PatchTSMixerForRegression(config)
output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
self.assertEqual(
output.prediction_outputs[0].shape,
output.regression_outputs[0].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(
output.prediction_outputs[1].shape,
output.regression_outputs[1].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
@@ -1075,13 +1095,13 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
config = PatchTSMixerConfig(**params)
mdl = PatchTSMixerForRegression(config)
output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
self.assertEqual(
output.prediction_outputs[0].shape,
output.regression_outputs[0].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(
output.prediction_outputs[1].shape,
output.regression_outputs[1].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)

View File

@@ -367,19 +367,19 @@ class PatchTSTModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, atol=TOLERANCE))
def test_regression_generation(self):
model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression").to(torch_device)
batch = prepare_batch(file="test-batch.pt")
model = PatchTSTForRegression.from_pretrained("ibm/patchtst-etth1-regression-distribution").to(torch_device)
batch = prepare_batch(repo_id="ibm/patchtst-etth1-test-data", file="regression_distribution_batch.pt")
torch.manual_seed(0)
model.eval()
with torch.no_grad():
outputs = model.generate(past_values=batch["past_values"].to(torch_device))
expected_shape = torch.Size((64, model.config.num_parallel_samples, model.config.num_targets))
self.assertEqual(outputs.sequences.shape, expected_shape)
expected_slice = torch.tensor(
[[0.3228, 0.4320, 0.4591, 0.4066, -0.3461, 0.3094, -0.8426]],
[[-0.08046409], [-0.06570087], [-0.28218266], [-0.20636195], [-0.11787311]],
device=torch_device,
)
mean_prediction = outputs.sequences.mean(dim=1)
self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, rtol=TOLERANCE))
self.assertTrue(torch.allclose(mean_prediction[-5:], expected_slice, rtol=TOLERANCE))