[gen utils] missing else case (#6980)

* [gen utils] missing else case

1. `else` is missing - I hit that case while porting a model. Probably needs to assert there?
2. also the comment on top seems to be outdated (just vocab_size is being set there)

* typo
This commit is contained in:
Stas Bekman
2020-09-07 04:28:06 -07:00
committed by GitHub
parent f7e80721eb
commit 848fbe1e35

View File

@@ -358,7 +358,7 @@ class GenerationMixin:
)
pad_token_id = eos_token_id
# current position and vocab size
# vocab size
if hasattr(self.config, "vocab_size"):
vocab_size = self.config.vocab_size
elif (
@@ -367,6 +367,8 @@ class GenerationMixin:
and hasattr(self.config.decoder, "vocab_size")
):
vocab_size = self.config.decoder.vocab_size
else:
raise ValueError("either self.config.vocab_size or self.config.decoder.vocab_size needs to be defined")
# set effective batch size and effective batch multiplier according to do_sample
if do_sample: