[gen utils] missing else case (#6980)
* [gen utils] missing else case 1. `else` is missing - I hit that case while porting a model. Probably needs to assert there? 2. also the comment on top seems to be outdated (just vocab_size is being set there) * typo
This commit is contained in:
@@ -358,7 +358,7 @@ class GenerationMixin:
|
||||
)
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
# current position and vocab size
|
||||
# vocab size
|
||||
if hasattr(self.config, "vocab_size"):
|
||||
vocab_size = self.config.vocab_size
|
||||
elif (
|
||||
@@ -367,6 +367,8 @@ class GenerationMixin:
|
||||
and hasattr(self.config.decoder, "vocab_size")
|
||||
):
|
||||
vocab_size = self.config.decoder.vocab_size
|
||||
else:
|
||||
raise ValueError("either self.config.vocab_size or self.config.decoder.vocab_size needs to be defined")
|
||||
|
||||
# set effective batch size and effective batch multiplier according to do_sample
|
||||
if do_sample:
|
||||
|
||||
Reference in New Issue
Block a user