PatchtTST and PatchTSMixer fixes (#28083)

* 🐛 fix .max bug

* remove prediction_length from regression output dimensions

* fix parameter names, fix output names, update tests

* ensure shape for PatchTST

* ensure output shape for PatchTSMixer

* update model, batch, and expected for regression distribution test

* update test expected

Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* Update tests/models/patchtst/test_modeling_patchtst.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtst/test_modeling_patchtst.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtst/test_modeling_patchtst.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/patchtsmixer/modeling_patchtsmixer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* standardize on patch_length

Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Make arguments more explicit

Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* adjust prepared inputs

Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

---------

Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>
Co-authored-by: Wesley M. Gifford <wmgifford@us.ibm.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Wesley Gifford
2024-01-29 05:09:26 -05:00
committed by GitHub
parent 3a08cc485f
commit f72c7c22d9
5 changed files with 115 additions and 89 deletions

View File

@@ -367,19 +367,19 @@ class PatchTSTModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, atol=TOLERANCE))
def test_regression_generation(self):
model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression").to(torch_device)
batch = prepare_batch(file="test-batch.pt")
model = PatchTSTForRegression.from_pretrained("ibm/patchtst-etth1-regression-distribution").to(torch_device)
batch = prepare_batch(repo_id="ibm/patchtst-etth1-test-data", file="regression_distribution_batch.pt")
torch.manual_seed(0)
model.eval()
with torch.no_grad():
outputs = model.generate(past_values=batch["past_values"].to(torch_device))
expected_shape = torch.Size((64, model.config.num_parallel_samples, model.config.num_targets))
self.assertEqual(outputs.sequences.shape, expected_shape)
expected_slice = torch.tensor(
[[0.3228, 0.4320, 0.4591, 0.4066, -0.3461, 0.3094, -0.8426]],
[[-0.08046409], [-0.06570087], [-0.28218266], [-0.20636195], [-0.11787311]],
device=torch_device,
)
mean_prediction = outputs.sequences.mean(dim=1)
self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, rtol=TOLERANCE))
self.assertTrue(torch.allclose(mean_prediction[-5:], expected_slice, rtol=TOLERANCE))