Set missing seq_length variable when using inputs_embeds with ALBERT & Remove code duplication (#13152)

* Set seq_length variable when using inputs_embeds

* remove code duplication
This commit is contained in:
Jongheon Kim
2021-08-31 19:51:25 +09:00
committed by GitHub
parent 180c6de6a6
commit ef8d6f2b4a
14 changed files with 14 additions and 27 deletions

View File

@@ -846,13 +846,12 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length