The add_space_before_punct_symbol is only for TransfoXL (#5549)
This commit is contained in:
@@ -214,8 +214,14 @@ def main():
|
|||||||
if requires_preprocessing:
|
if requires_preprocessing:
|
||||||
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
||||||
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
|
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
|
||||||
|
|
||||||
|
if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
|
||||||
|
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
|
||||||
|
else:
|
||||||
|
tokenizer_kwargs = {}
|
||||||
|
|
||||||
encoded_prompt = tokenizer.encode(
|
encoded_prompt = tokenizer.encode(
|
||||||
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
|
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||||
|
|||||||
Reference in New Issue
Block a user