From 9eae4aa57650c1dbe1becd4e0979f6ad1e572ac0 Mon Sep 17 00:00:00 2001 From: Eli Simhayev Date: Mon, 3 Apr 2023 20:07:21 +0700 Subject: [PATCH] [Time-Series] fix past_observed_mask type (#22076) added > 0.5 to `past_observed_mask` --- tests/models/informer/test_modeling_informer.py | 2 +- .../test_modeling_time_series_transformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py index 91e6fb74f6..271f997bee 100644 --- a/tests/models/informer/test_modeling_informer.py +++ b/tests/models/informer/test_modeling_informer.py @@ -117,7 +117,7 @@ class InformerModelTester: past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features]) past_values = floats_tensor([self.batch_size, _past_length]) - past_observed_mask = floats_tensor([self.batch_size, _past_length]) + past_observed_mask = floats_tensor([self.batch_size, _past_length]) > 0.5 # decoder inputs future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features]) diff --git a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py index 7f14f29de0..65834dac42 100644 --- a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py +++ b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py @@ -114,7 +114,7 @@ class TimeSeriesTransformerModelTester: past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features]) past_values = floats_tensor([self.batch_size, _past_length]) - past_observed_mask = floats_tensor([self.batch_size, _past_length]) + past_observed_mask = floats_tensor([self.batch_size, _past_length]) > 0.5 # decoder inputs future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features])