[Llama] remove prompt and fix prefix finetuning (#25565)
* nit
* update
* make sure use_default_system_prompt is saved
* update checkpointing
* consistency
* use_default_system_prompt for test
This commit is contained in:
@@ -683,7 +683,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
# None for past_key_value
|
# None for past_key_value
|
||||||
return module(*inputs, output_attentions, None)
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -692,7 +692,6 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
None,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
|||||||
add_bos_token=True,
|
add_bos_token=True,
|
||||||
add_eos_token=False,
|
add_eos_token=False,
|
||||||
clean_up_tokenization_spaces=False,
|
clean_up_tokenization_spaces=False,
|
||||||
|
use_default_system_prompt=True,
|
||||||
spaces_between_special_tokens=False,
|
spaces_between_special_tokens=False,
|
||||||
legacy=None,
|
legacy=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -131,6 +132,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
|||||||
add_eos_token=add_eos_token,
|
add_eos_token=add_eos_token,
|
||||||
sp_model_kwargs=self.sp_model_kwargs,
|
sp_model_kwargs=self.sp_model_kwargs,
|
||||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
|
use_default_system_prompt=use_default_system_prompt,
|
||||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||||
legacy=legacy,
|
legacy=legacy,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -149,8 +151,9 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
|||||||
self.vocab_file = vocab_file
|
self.vocab_file = vocab_file
|
||||||
self.add_bos_token = add_bos_token
|
self.add_bos_token = add_bos_token
|
||||||
self.add_eos_token = add_eos_token
|
self.add_eos_token = add_eos_token
|
||||||
self.sp_model = self.get_spm_processor()
|
self.use_default_system_prompt = use_default_system_prompt
|
||||||
|
|
||||||
|
self.sp_model = self.get_spm_processor()
|
||||||
self.unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
|
self.unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
|
||||||
|
|
||||||
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
|
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
|
||||||
@@ -390,16 +393,20 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
|||||||
`List[int]`:
|
`List[int]`:
|
||||||
Input ids for the conversation.
|
Input ids for the conversation.
|
||||||
"""
|
"""
|
||||||
if len(conversation.past_user_inputs) > 0:
|
if self.use_default_system_prompt:
|
||||||
if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
|
if len(conversation.past_user_inputs) > 0:
|
||||||
conversation.past_user_inputs[0] = (
|
if (
|
||||||
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
|
not conversation.past_user_inputs[0].startswith(B_SYS)
|
||||||
)
|
or E_SYS not in conversation.past_user_inputs[0]
|
||||||
elif conversation.new_user_input:
|
):
|
||||||
if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
|
conversation.past_user_inputs[0] = (
|
||||||
conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
|
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
|
||||||
else:
|
)
|
||||||
raise ValueError("Last message must be from user")
|
elif conversation.new_user_input:
|
||||||
|
if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
|
||||||
|
conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
|
||||||
|
else:
|
||||||
|
raise ValueError("Last message must be from user")
|
||||||
|
|
||||||
dialogue = list(conversation.iter_texts())
|
dialogue = list(conversation.iter_texts())
|
||||||
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
|
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
eos_token="</s>",
|
eos_token="</s>",
|
||||||
add_bos_token=True,
|
add_bos_token=True,
|
||||||
add_eos_token=False,
|
add_eos_token=False,
|
||||||
|
use_default_system_prompt=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -119,12 +120,13 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
unk_token=unk_token,
|
unk_token=unk_token,
|
||||||
bos_token=bos_token,
|
bos_token=bos_token,
|
||||||
eos_token=eos_token,
|
eos_token=eos_token,
|
||||||
|
use_default_system_prompt=use_default_system_prompt,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self._add_bos_token = add_bos_token
|
self._add_bos_token = add_bos_token
|
||||||
self._add_eos_token = add_eos_token
|
self._add_eos_token = add_eos_token
|
||||||
self.update_post_processor()
|
self.update_post_processor()
|
||||||
|
self.use_default_system_prompt = use_default_system_prompt
|
||||||
self.vocab_file = vocab_file
|
self.vocab_file = vocab_file
|
||||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||||
|
|
||||||
@@ -212,16 +214,20 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
`List[int]`:
|
`List[int]`:
|
||||||
Input ids for the conversation.
|
Input ids for the conversation.
|
||||||
"""
|
"""
|
||||||
if len(conversation.past_user_inputs) > 0:
|
if self.use_default_system_prompt:
|
||||||
if not conversation.past_user_inputs[0].startswith(B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
|
if len(conversation.past_user_inputs) > 0:
|
||||||
conversation.past_user_inputs[0] = (
|
if (
|
||||||
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
|
not conversation.past_user_inputs[0].startswith(B_SYS)
|
||||||
)
|
or E_SYS not in conversation.past_user_inputs[0]
|
||||||
elif conversation.new_user_input:
|
):
|
||||||
if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
|
conversation.past_user_inputs[0] = (
|
||||||
conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
|
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
|
||||||
else:
|
)
|
||||||
raise ValueError("Last message must be from user")
|
elif conversation.new_user_input:
|
||||||
|
if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
|
||||||
|
conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
|
||||||
|
else:
|
||||||
|
raise ValueError("Last message must be from user")
|
||||||
|
|
||||||
dialogue = list(conversation.iter_texts())
|
dialogue = list(conversation.iter_texts())
|
||||||
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
|
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ class ConversationalPipelineTests(unittest.TestCase):
|
|||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_integration_torch_conversation_llama2_input_ids(self):
|
def test_integration_torch_conversation_llama2_input_ids(self):
|
||||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_default_system_prompt=True)
|
||||||
|
|
||||||
conversation = Conversation(
|
conversation = Conversation(
|
||||||
"What is so great about #1?",
|
"What is so great about #1?",
|
||||||
|
|||||||
Reference in New Issue
Block a user