Generation: fix handling of special tokens (#31254)
* fix special tokens in generatioon * fix test * add warning * fix the check * warn once * fix
This commit is contained in:
committed by
GitHub
parent
7729b77478
commit
5fabd1e83b
@@ -1436,23 +1436,6 @@ class GenerationMixin:
|
|||||||
self._cache.reset()
|
self._cache.reset()
|
||||||
return self._cache
|
return self._cache
|
||||||
|
|
||||||
def _get_decoder_start_token_id(
|
|
||||||
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
|
|
||||||
) -> int:
|
|
||||||
decoder_start_token_id = (
|
|
||||||
decoder_start_token_id
|
|
||||||
if decoder_start_token_id is not None
|
|
||||||
else self.generation_config.decoder_start_token_id
|
|
||||||
)
|
|
||||||
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
|
|
||||||
|
|
||||||
if decoder_start_token_id is not None:
|
|
||||||
return decoder_start_token_id
|
|
||||||
elif bos_token_id is not None:
|
|
||||||
return bos_token_id
|
|
||||||
else:
|
|
||||||
return
|
|
||||||
|
|
||||||
def _supports_default_dynamic_cache(self) -> bool:
|
def _supports_default_dynamic_cache(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
|
Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
|
||||||
@@ -1478,25 +1461,32 @@ class GenerationMixin:
|
|||||||
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
|
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Convert special tokens to tensors (if they exist)
|
# Convert special tokens to tensors (if they exist either in kwargs or in self.config)
|
||||||
def _tensor_or_none(token, device=None):
|
def _tensor_or_none(token_kwargs, token_self, device=None):
|
||||||
if device is None:
|
if device is None:
|
||||||
device = self.device
|
device = self.device
|
||||||
|
|
||||||
|
token = token_kwargs if token_kwargs is not None else token_self
|
||||||
if token is None or isinstance(token, torch.Tensor):
|
if token is None or isinstance(token, torch.Tensor):
|
||||||
return token
|
return token
|
||||||
return torch.tensor(token, device=device, dtype=torch.long)
|
return torch.tensor(token, device=device, dtype=torch.long)
|
||||||
|
|
||||||
# for BC we also try to get `decoder_start_token_id` from model's generation config (#30892)
|
bos_token_id = _tensor_or_none(
|
||||||
if self.config.is_encoder_decoder:
|
generation_config.bos_token_id, self.generation_config.bos_token_id, device=device
|
||||||
generation_config.decoder_start_token_id = self._get_decoder_start_token_id(
|
)
|
||||||
generation_config.decoder_start_token_id, generation_config.bos_token_id
|
eos_token_id = _tensor_or_none(
|
||||||
)
|
generation_config.eos_token_id, self.generation_config.eos_token_id, device=device
|
||||||
|
)
|
||||||
|
pad_token_id = _tensor_or_none(
|
||||||
|
generation_config.pad_token_id, self.generation_config.pad_token_id, device=device
|
||||||
|
)
|
||||||
|
decoder_start_token_id = _tensor_or_none(
|
||||||
|
generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device
|
||||||
|
)
|
||||||
|
|
||||||
bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
|
# for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
|
||||||
eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
|
if self.config.is_encoder_decoder:
|
||||||
pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
|
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
|
||||||
decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
|
|
||||||
|
|
||||||
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
|
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
|
||||||
if eos_token_id is not None and eos_token_id.ndim == 0:
|
if eos_token_id is not None and eos_token_id.ndim == 0:
|
||||||
@@ -1512,6 +1502,15 @@ class GenerationMixin:
|
|||||||
pad_token_id = eos_token_id[0]
|
pad_token_id = eos_token_id[0]
|
||||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
|
||||||
|
|
||||||
|
# we can't infer attn mask if pad token is set to be eos token in model's generation config
|
||||||
|
if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
|
||||||
|
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
||||||
|
logger.warning_once(
|
||||||
|
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
|
||||||
|
"As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
|
||||||
|
"to obtain reliable results."
|
||||||
|
)
|
||||||
|
|
||||||
# Sanity checks/warnings
|
# Sanity checks/warnings
|
||||||
if self.config.is_encoder_decoder and decoder_start_token_id is None:
|
if self.config.is_encoder_decoder and decoder_start_token_id is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -161,6 +161,7 @@ class GenerationIntegrationTestsMixin:
|
|||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
model = model_cls.from_pretrained("distilbert/distilgpt2")
|
model = model_cls.from_pretrained("distilbert/distilgpt2")
|
||||||
|
model.generation_config.eos_token_id = None
|
||||||
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
|
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
|
||||||
if is_pt:
|
if is_pt:
|
||||||
model = model.to(torch_device)
|
model = model.to(torch_device)
|
||||||
@@ -170,7 +171,6 @@ class GenerationIntegrationTestsMixin:
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
max_new_tokens=5,
|
max_new_tokens=5,
|
||||||
pad_token_id=tokenizer.eos_token_id,
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
eos_token_id=None,
|
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
)
|
)
|
||||||
@@ -197,6 +197,7 @@ class GenerationIntegrationTestsMixin:
|
|||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
model = model_cls.from_pretrained("distilbert/distilgpt2")
|
model = model_cls.from_pretrained("distilbert/distilgpt2")
|
||||||
|
model.generation_config.eos_token_id = None
|
||||||
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
|
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
|
||||||
if is_pt:
|
if is_pt:
|
||||||
model = model.to(torch_device)
|
model = model.to(torch_device)
|
||||||
@@ -206,7 +207,6 @@ class GenerationIntegrationTestsMixin:
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
max_new_tokens=5,
|
max_new_tokens=5,
|
||||||
pad_token_id=tokenizer.eos_token_id,
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
eos_token_id=None,
|
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user