[generate] move max time tests (#35962)
* move max time tests to their right place * move test to the right place
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import datetime
|
||||
import unittest
|
||||
|
||||
from transformers import GPTJConfig, is_torch_available
|
||||
@@ -546,47 +545,6 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase):
|
||||
all(output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs)))
|
||||
) # token_type_ids should change output
|
||||
|
||||
@slow
|
||||
def test_gptj_sample_max_time(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("anton-l/gpt-j-tiny-random")
|
||||
model = GPTJForCausalLM.from_pretrained("anton-l/gpt-j-tiny-random")
|
||||
model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
|
||||
input_ids = tokenized.input_ids.to(torch_device)
|
||||
|
||||
MAX_TIME = 0.5
|
||||
|
||||
start = datetime.datetime.now()
|
||||
model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
|
||||
duration = datetime.datetime.now() - start
|
||||
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
|
||||
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
|
||||
|
||||
start = datetime.datetime.now()
|
||||
model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
|
||||
duration = datetime.datetime.now() - start
|
||||
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
|
||||
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
|
||||
|
||||
start = datetime.datetime.now()
|
||||
model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
|
||||
duration = datetime.datetime.now() - start
|
||||
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
|
||||
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
|
||||
|
||||
start = datetime.datetime.now()
|
||||
model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
|
||||
duration = datetime.datetime.now() - start
|
||||
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
|
||||
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
|
||||
|
||||
start = datetime.datetime.now()
|
||||
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
|
||||
duration = datetime.datetime.now() - start
|
||||
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
|
||||
|
||||
@tooslow
|
||||
def test_contrastive_search_gptj(self):
|
||||
article = (
|
||||
|
||||
Reference in New Issue
Block a user