From bc3e20dcf08a03e22a0e4a42a0ce5a8ec94180e5 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 18 Aug 2023 13:39:23 +0200 Subject: [PATCH] [`Llama`] remove prompt and fix prefix finetuning (#25565) * nit * update * make sure use_default_system_prompt is saved * update checkpointing * consistency * use_default_system_prompt for test --- .../models/llama/modeling_llama.py | 3 +- .../models/llama/tokenization_llama.py | 29 ++++++++++++------- .../models/llama/tokenization_llama_fast.py | 28 +++++++++++------- .../test_pipelines_conversational.py | 2 +- 4 files changed, 37 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 84d63fae15..309c3ef1de 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -683,7 +683,7 @@ class LlamaModel(LlamaPreTrainedModel): def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, output_attentions, None) + return module(*inputs, past_key_value, output_attentions) return custom_forward @@ -692,7 +692,6 @@ class LlamaModel(LlamaPreTrainedModel): hidden_states, attention_mask, position_ids, - None, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/llama/tokenization_llama.py b/src/transformers/models/llama/tokenization_llama.py index d47841b8b2..65ae8e2bd6 100644 --- a/src/transformers/models/llama/tokenization_llama.py +++ b/src/transformers/models/llama/tokenization_llama.py @@ -113,6 +113,7 @@ class LlamaTokenizer(PreTrainedTokenizer): add_bos_token=True, add_eos_token=False, clean_up_tokenization_spaces=False, + use_default_system_prompt=True, spaces_between_special_tokens=False, legacy=None, **kwargs, @@ -131,6 +132,7 @@ class LlamaTokenizer(PreTrainedTokenizer): add_eos_token=add_eos_token, sp_model_kwargs=self.sp_model_kwargs, clean_up_tokenization_spaces=clean_up_tokenization_spaces, + use_default_system_prompt=use_default_system_prompt, spaces_between_special_tokens=spaces_between_special_tokens, legacy=legacy, **kwargs, @@ -149,8 +151,9 @@ class LlamaTokenizer(PreTrainedTokenizer): self.vocab_file = vocab_file self.add_bos_token = add_bos_token self.add_eos_token = add_eos_token - self.sp_model = self.get_spm_processor() + self.use_default_system_prompt = use_default_system_prompt + self.sp_model = self.get_spm_processor() self.unk_token_length = len(self.sp_model.encode(str(self.unk_token))) # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor @@ -390,16 +393,20 @@ class LlamaTokenizer(PreTrainedTokenizer): `List[int]`: Input ids for the conversation. """ - if len(conversation.past_user_inputs) > 0: - if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]: - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] - ) - elif conversation.new_user_input: - if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: - conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input - else: - raise ValueError("Last message must be from user") + if self.use_default_system_prompt: + if len(conversation.past_user_inputs) > 0: + if ( + not conversation.past_user_inputs[0].startswith(B_SYS) + or E_SYS not in conversation.past_user_inputs[0] + ): + conversation.past_user_inputs[0] = ( + B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] + ) + elif conversation.new_user_input: + if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: + conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input + else: + raise ValueError("Last message must be from user") dialogue = list(conversation.iter_texts()) if not all([is_user for is_user, msg in dialogue[::2]]) or not all( diff --git a/src/transformers/models/llama/tokenization_llama_fast.py b/src/transformers/models/llama/tokenization_llama_fast.py index 533c6adf7b..785869ea66 100644 --- a/src/transformers/models/llama/tokenization_llama_fast.py +++ b/src/transformers/models/llama/tokenization_llama_fast.py @@ -110,6 +110,7 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast): eos_token="", add_bos_token=True, add_eos_token=False, + use_default_system_prompt=True, **kwargs, ): super().__init__( @@ -119,12 +120,13 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast): unk_token=unk_token, bos_token=bos_token, eos_token=eos_token, + use_default_system_prompt=use_default_system_prompt, **kwargs, ) self._add_bos_token = add_bos_token self._add_eos_token = add_eos_token self.update_post_processor() - + self.use_default_system_prompt = use_default_system_prompt self.vocab_file = vocab_file self.can_save_slow_tokenizer = False if not self.vocab_file else True @@ -212,16 +214,20 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast): `List[int]`: Input ids for the conversation. """ - if len(conversation.past_user_inputs) > 0: - if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]: - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] - ) - elif conversation.new_user_input: - if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: - conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input - else: - raise ValueError("Last message must be from user") + if self.use_default_system_prompt: + if len(conversation.past_user_inputs) > 0: + if ( + not conversation.past_user_inputs[0].startswith(B_SYS) + or E_SYS not in conversation.past_user_inputs[0] + ): + conversation.past_user_inputs[0] = ( + B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0] + ) + elif conversation.new_user_input: + if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input: + conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input + else: + raise ValueError("Last message must be from user") dialogue = list(conversation.iter_texts()) if not all([is_user for is_user, msg in dialogue[::2]]) or not all( diff --git a/tests/pipelines/test_pipelines_conversational.py b/tests/pipelines/test_pipelines_conversational.py index 2406357b89..efb2215f49 100644 --- a/tests/pipelines/test_pipelines_conversational.py +++ b/tests/pipelines/test_pipelines_conversational.py @@ -220,7 +220,7 @@ class ConversationalPipelineTests(unittest.TestCase): @require_torch @slow def test_integration_torch_conversation_llama2_input_ids(self): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_default_system_prompt=True) conversation = Conversation( "What is so great about #1?",