TF: XLA beam search + most generation-compatible models are now also XLA-generate-compatible (#17857)
* working beam search 🎉
* XLA generation compatible with ALL classes
* add xla generation slow test
This commit is contained in:
@@ -152,23 +152,6 @@ class TFBartModelTester:
|
||||
# test that outputs are equal for slice
|
||||
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
|
||||
|
||||
def create_and_check_bart_xla_generate_fast(self, config, input_ids, *args):
|
||||
config.eos_token_id = None # Generate until max length
|
||||
config.max_length = 10
|
||||
config.do_sample = False
|
||||
config.num_beams = 1
|
||||
model = TFBartForConditionalGeneration(config=config)
|
||||
|
||||
# make sure there are no pad tokens in prompt
|
||||
input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1)
|
||||
|
||||
generated = model.generate(input_ids)
|
||||
|
||||
generate_xla = tf.function(model.generate, jit_compile=True)
|
||||
generated_xla = generate_xla(input_ids)
|
||||
|
||||
self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
|
||||
|
||||
|
||||
def prepare_bart_inputs_dict(
|
||||
config,
|
||||
@@ -310,10 +293,6 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_bart_model_xla_generate_fast(self):
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self.model_tester.create_and_check_bart_xla_generate_fast(config, inputs["input_ids"])
|
||||
|
||||
def test_saved_model_creation(self):
|
||||
# This test is too long (>30sec) and makes the CI fail
|
||||
pass
|
||||
@@ -703,10 +682,8 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
|
||||
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
assert result == EXPECTED
|
||||
|
||||
def test_xsum_1_1_xla_greedy_generation(self):
|
||||
# TODO (Joao): this is a temporary test, while XLA beam search is not operational. Move the XLA==non-XLA
|
||||
# comparisons to the other tests after enabling XLA beam search.
|
||||
# Note -- `no_repeat_ngram_size` has to be disabled, since it is not compatible with XLA
|
||||
def test_xsum_1_1_xla_generation(self):
|
||||
# same test as above, but with `no_repeat_ngram_size=0` (not compatible with XLA) and XLA comparison enabled
|
||||
model = self.xsum_1_1_model
|
||||
assert model.model.decoder.embed_tokens._layer == model.model.shared
|
||||
ARTICLE = (
|
||||
@@ -748,15 +725,16 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
EXPECTED = (
|
||||
" The International Criminal Court (ICC) has announced that it is to be investigated by the International"
|
||||
" Criminal Court (ICC) over claims that the Palestinian genocide."
|
||||
" Criminal Court (ICC) over allegations of war crimes."
|
||||
)
|
||||
|
||||
dct = self.tok(ARTICLE, return_tensors="tf")
|
||||
generated_ids = model.generate(**dct, num_beams=1, no_repeat_ngram_size=0)
|
||||
generated_ids = model.generate(**dct, num_beams=4, no_repeat_ngram_size=0)
|
||||
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
assert result == EXPECTED
|
||||
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
generated_ids = xla_generate(**dct, num_beams=1, no_repeat_ngram_size=0)
|
||||
generated_ids = xla_generate(**dct, num_beams=4, no_repeat_ngram_size=0)
|
||||
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
assert result == EXPECTED
|
||||
|
||||
|
||||
Reference in New Issue
Block a user