TF: BART compatible with XLA generation (#17479)

* Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus
This commit is contained in:
Joao Gante
2022-06-20 11:07:46 +01:00
committed by GitHub
parent 6589e510fa
commit 132402d752
18 changed files with 421 additions and 86 deletions

View File

@@ -125,8 +125,22 @@ class TFBartModelTester:
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]
decoder_position_ids = tf.cast(tf.cumsum(next_attention_mask, axis=1, exclusive=True), dtype=tf.int32)
output_from_no_past = model(
next_input_ids, attention_mask=next_attention_mask, position_ids=decoder_position_ids
)
output_from_no_past = output_from_no_past[0]
decoder_position_ids = (
tf.cast(tf.cumsum(next_attn_mask, axis=1, exclusive=True), dtype=tf.int32) + past_key_values[0][0].shape[2]
)
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
position_ids=decoder_position_ids,
)
output_from_past = output_from_past[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
@@ -138,6 +152,23 @@ class TFBartModelTester:
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_bart_xla_generate_fast(self, config, input_ids, *args):
config.eos_token_id = None # Generate until max length
config.max_length = 10
config.do_sample = False
config.num_beams = 1
model = TFBartForConditionalGeneration(config=config)
# make sure there are no pad tokens in prompt
input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1)
generated = model.generate(input_ids)
generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(input_ids)
self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
def prepare_bart_inputs_dict(
config,
@@ -279,25 +310,15 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
models_equal = False
self.assertTrue(models_equal)
def test_bart_model_xla_generate_fast(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.create_and_check_bart_xla_generate_fast(config, inputs["input_ids"])
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
if a is None and b is None:
return True
try:
if tf.debugging.assert_near(a, b, atol=atol):
return True
raise
except Exception:
if len(prefix) > 0:
prefix = f"{prefix}: "
raise AssertionError(f"{prefix}{a} != {b}")
def _long_tensor(tok_lst):
return tf.constant(tok_lst, dtype=tf.int32)
@@ -682,6 +703,63 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED
def test_xsum_1_1_xla_greedy_generation(self):
# TODO (Joao): this is temporary test, while XLA beam search is not operational. Move the XLA==non-XLA
# comparisons to the other tests after enabling XLA beam search.
# Note -- `no_repeat_ngram_size` has to be disabled, since it is not compatible with XLA
model = self.xsum_1_1_model
assert model.model.decoder.embed_tokens._layer == model.model.shared
ARTICLE = (
"The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
" Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
" formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
" The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
" situation in Palestinian territories, paving the way for possible war crimes investigations against"
" Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
" the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
" body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
" Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
" acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
" immediately end their pressure, and countries that support universal acceptance of the court's treaty"
' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
" group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
' decision to join a treaty to which over 100 countries around the world are members." In January, when'
" the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
" disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
" it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
" court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
" between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
" will include alleged war crimes committed since June. The International Criminal Court was set up in"
" 2002 to prosecute genocide, crimes against humanity and war crimes."
)
EXPECTED = (
" The International Criminal Court (ICC) has announced that it is to be investigated by the International"
" Criminal Court (ICC) over claims that the Palestinian genocide."
)
dct = self.tok(ARTICLE, return_tensors="tf")
generated_ids = model.generate(**dct, num_beams=1, no_repeat_ngram_size=0)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED
xla_generate = tf.function(model.generate, jit_compile=True)
generated_ids = xla_generate(**dct, num_beams=1, no_repeat_ngram_size=0)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
assert result == EXPECTED
def test_xsum_1_1_batch_generation(self):
batch = self.tok(
[

View File

@@ -295,7 +295,7 @@ class TFGPT2ModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_gpt2_xla_generate_fast(self, config, input_ids, *args):
config.eos_token_id = None
config.eos_token_id = None # Generate until max length
config.max_length = 10
model = TFGPT2LMHeadModel(config=config)

View File

@@ -228,7 +228,7 @@ class TFT5ModelTester:
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_t5_xla_generate_fast(self, config, input_ids, *args):
config.eos_token_id = None
config.eos_token_id = None # Generate until max length
config.max_length = 10
config.do_sample = False
config.num_beams = 1