Fix GPT2DoubleHeadsModel to work with model.generate() (#6601)
* Fix passing of token_type_ids during generate() for GPT2DoubleHeadsModel (when token_type_ids are used), and for GPT2LMHeadModel too
* Update tests to check token_type_ids usage in GPT2 models
This commit is contained in:
@@ -144,6 +144,10 @@ class GenerationMixin:
|
||||
)
|
||||
input_ids = input_ids.index_select(0, expanded_return_idx)
|
||||
|
||||
if "token_type_ids" in model_kwargs:
|
||||
token_type_ids = model_kwargs["token_type_ids"]
|
||||
model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
|
||||
|
||||
if attention_mask is not None:
|
||||
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
|
||||
|
||||
@@ -194,6 +198,11 @@ class GenerationMixin:
|
||||
else:
|
||||
model_kwargs["past"] = None
|
||||
|
||||
# update token_type_ids with last value
|
||||
if "token_type_ids" in model_kwargs:
|
||||
token_type_ids = model_kwargs["token_type_ids"]
|
||||
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
|
||||
|
||||
# update attention mask
|
||||
if not is_encoder_decoder:
|
||||
if "attention_mask" in model_kwargs:
|
||||
|
||||
Reference in New Issue
Block a user