Add TF whisper (#19378)
* simplify loop * add featur extractor * add model * start conversion * add dropout * initial commit of test files * copnversion for all models * update processor for correct padding * update feature extraction * update integration test logits match * fmnt: off for the logits * on the fly mel bank * small nit * update test * update tokenizer * nit feature extraction * update * update tokenizer test * adds logit processor and update tokenizer to get supress tokens * style * clean convert * revert to original modeling tf utils * Update * update * nit * clean convert file * update tests and nits * quality * slow generation test * ffn_dim to allow customization * update readme * add to toctreee * start fixing integration tests * update tests and code * fix feature extractor * fix config tests common * update code to fix tests * fix feature exctractor * nit feature extraction * update test for new feature extractor * style * add absrtact * large logits wioth custom decoder input ids * wraap around is otrch available * fix feature extractor * correct logits for whisper small.en * nit * fix encoder_attentino_mask * some fixes * remove unnecessary inputs * nits * add normalizer file * update etst tokenization * fix attention mask not defined * fix generate * remove uncoder attention mask useless * update test modeling whisper * update condfig to add second non supress tokens * nits on feature exrtactor * nit for test tokenizers * update etsts * update tests * update tokenization test * fixup * invalidated hf token. Clean convert openai to whisper * fix logit tests * fixup * Add model to README * Fix doc tests * clean merge * revert toc_tree changes * remove useless LogitProcessor * Update whisper .mdx * update config file doc * update configuration docstring * update test tokenization * update test tokenization * update tokenization whisper Added copied from where needed * update feature extraction * nit test name * style * quality * remove get suppress tokens and update non_speech tokens global variables * Update src/transformers/models/whisper/feature_extraction_whisper.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * clean modeling whisper and test Removed the attention mask arguments that are deprecated * fix large test * Add multilingual audio test, and translate test * style * fix larg multilingual test * nits * add copied from for attention layer * remove attention masks in doc * add english normalizer * Update docs/source/en/model_doc/whisper.mdx Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * update tokenization test * remove copied from in whisper attention : no bias in k_proj only * wrap around dependencies in english normalizer * style * correct import generation logits * for now, wrap feature extractor with torch * remove torch depencies for feature extraction and style * Update src/transformers/models/whisper/convert_openai_whisper_to_tfms.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/whisper/configuration_whisper.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/whisper.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * fixup * nit * update logitds * style * nit * nits and fix final tests * add `is_more_itertools_available` to utils * quality * add begin supress tokens, supress tokens to generate args and config * clean supressTokensLogitProcessor in generation logits * Nit naming * add supressTokensAtBegin * udpate tests, supress tokens to None or correct values * nit and style * update RAG to fit test and generate_logit * add copy pasted statment on english normalizer * add arguments to config_common_kwargs * Update src/transformers/generation_utils.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/generation_logits_process.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * revert changes based on reviews * update doc and nits * Update src/transformers/models/whisper/configuration_whisper.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * more nits * last nits * update test configuration common * add BART name in decoder attention mask documentation * Update src/transformers/models/whisper/modeling_whisper.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * style * nit * nit * add english.json file to git * nits on documentation * nit * nits * last styling * add main toctree file * remove sentence piece dependency * clean init file * fix tokenizer that has no dependencies on sentencepiece * update whisper init file, nit * remove english.json file * add get decoder prompt id * All weights loading * Remove hanging pdb * Fixup and tidy up * Use same copied from as PT model * Remove whitespace changes * Remove torch references * Tie embeddings * Remove logits processor input to generate * Update logit values * revert changes and add forced logit processor * nit * clean normalizer * remove protected * Add logit processors and update generation code & tests * Some tidy up * Update docstring * update * update based on review * Update src/transformers/models/whisper/configuration_whisper.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/whisper/configuration_whisper.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update to reflect changes on the PT model branch * Tidy up * Remove extra whitespace * Fix test - make input ids small enough we can append * Include upstream changes on main * PR comments - add batch tests, remove comments & defaults * Fix model output imports * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation_tf_logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/models/whisper/test_modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update docstring example * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Remove changes to adjust_logits_during_generation function * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Tidy up imports that don't require TF * Update tests - skip and no more skip * Update tests/generation/test_generation_tf_logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Add training flags * Add (skipped) XLA generation tests * Add embedding correctness test * Add constant ids for generation tests * Make logits finding a bit tidier * Remove unused args * xla generation enabled * Don't skip XLA tests anymore * Fix tests - add position ids to expected signature and update rag generation * Undo method reorder * Remove added whitespace * Remove copy-paste gradient checkopint ref * Remove * Trigger CI - (issue with refs when pulling) Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: NielsRogge <niels.rogge1@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -29,11 +29,14 @@ if is_tf_available():
|
||||
from transformers.generation_tf_logits_process import (
|
||||
TFForcedBOSTokenLogitsProcessor,
|
||||
TFForcedEOSTokenLogitsProcessor,
|
||||
TFForceTokensLogitsProcessor,
|
||||
TFLogitsProcessorList,
|
||||
TFMinLengthLogitsProcessor,
|
||||
TFNoBadWordsLogitsProcessor,
|
||||
TFNoRepeatNGramLogitsProcessor,
|
||||
TFRepetitionPenaltyLogitsProcessor,
|
||||
TFSuppressTokensAtBeginLogitsProcessor,
|
||||
TFSuppressTokensLogitsProcessor,
|
||||
TFTemperatureLogitsWarper,
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
@@ -331,6 +334,86 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_suppress_tokens_at_begin_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
|
||||
begin_suppress_tokens = [1, 2, 3]
|
||||
begin_index = 5
|
||||
|
||||
logits_processor = TFSuppressTokensAtBeginLogitsProcessor(
|
||||
begin_suppress_tokens=begin_suppress_tokens, begin_index=begin_index
|
||||
)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# Check that no scores are suppressed if begin_index is not reached
|
||||
cur_len = 4
|
||||
input_ids = tf.convert_to_tensor([[11, 17, 15, 8], [14, 0, 19, 5], [13, 11, 18, 19], [11, 12, 16, 15]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
# Check that scores are suppressed if begin_index is reached
|
||||
cur_len = 5
|
||||
input_ids = tf.convert_to_tensor([[5, 5, 5, 0, 17], [18, 1, 9, 14, 17], [18, 6, 8, 15, 19], [8, 12, 17, 1, 2]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertTrue(tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, begin_suppress_tokens, axis=1))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_suppress_tokens_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
|
||||
suppress_tokens = [1, 3, 5]
|
||||
keep_tokens = [i for i in range(vocab_size) if i not in suppress_tokens]
|
||||
|
||||
logits_processor = TFSuppressTokensLogitsProcessor(suppress_tokens=suppress_tokens)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# Check that suppress_tokens are suppressed and others are not
|
||||
cur_len = 5
|
||||
input_ids = tf.convert_to_tensor([[0, 10, 19, 6, 3], [17, 4, 8, 17, 2], [7, 1, 11, 6, 15], [5, 8, 13, 16, 0]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertTrue(tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, suppress_tokens, axis=1))))
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf(tf.gather(scores, keep_tokens, axis=1))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_force_tokens_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
|
||||
force_token_map = {1: 2, 3: 2}
|
||||
|
||||
logits_processor = TFForceTokensLogitsProcessor(force_token_map=force_token_map)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# check that if the cur_len is contained in the force_token_map, the logits are the same
|
||||
# for all tokens except the one the force_token_map points to
|
||||
cur_len = 1
|
||||
input_ids = tf.convert_to_tensor([[11], [7], [5], [15]])
|
||||
ids_tensor((batch_size, cur_len), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
tf.debugging.assert_near(tf.gather(scores, [force_token_map[cur_len]], axis=1), 0.0)
|
||||
|
||||
non_forced_inds = [i for i in range(vocab_size) if i != force_token_map[cur_len]]
|
||||
self.assertTrue(
|
||||
tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, [non_forced_inds], axis=1))),
|
||||
)
|
||||
|
||||
# check that if the cur_len is not contained in the force_token_map, the logits are not modified
|
||||
cur_len = 2
|
||||
input_ids = tf.convert_to_tensor([[2, 19], [19, 15], [4, 9], [7, 6]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_processor_list(self, use_xla):
|
||||
# TODO (Joao): reintroduce TFNoRepeatNGramLogitsProcessor when it gets compatible with XLA
|
||||
|
||||
Reference in New Issue
Block a user