Fix GPT2DoubleHeadsModel to work with model.generate() (#6601)

* Fix passing token_type_ids during GPT2DoubleHeadsModel.generate() if used

and for GPT2LMHeadModel too

* Update tests to check token_type_ids usage in GPT2 models
This commit is contained in:
LSinev
2020-11-16 16:35:44 +03:00
committed by GitHub
parent 04d8136bde
commit afb50c663a
3 changed files with 121 additions and 2 deletions

View File

@@ -144,6 +144,10 @@ class GenerationMixin:
)
input_ids = input_ids.index_select(0, expanded_return_idx)
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
@@ -194,6 +198,11 @@ class GenerationMixin:
else:
model_kwargs["past"] = None
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
# update attention mask
if not is_encoder_decoder:
if "attention_mask" in model_kwargs: