Constrained Beam Search [*With* Disjunctive Decoding] (#15761)
* added classes to get started with constrained beam search * in progress, think i can directly force tokens now but not yet with the round robin * think now i have total control, now need to code the bank selection * technically works as desired, need to optimize and fix design choices leading to undersirable outputs * complete PR #1 without disjunctive decoding * removed incorrect tests * Delete k.txt * Delete test.py * Delete test.sh * revert changes to test scripts * genutils * full implementation with testing, no disjunctive yet * shifted docs * passing all tests realistically ran locally * removing accidentally included print statements * fixed source of error in initial PR test * fixing the get_device() vs device trap * fixed documentation docstrings about constrained_beam_search * fixed tests having failing for Speech2TextModel's floating point inputs * fix cuda long tensor * added examples and testing for them and founx & fixed a bug in beam_search and constrained_beam_search * deleted accidentally added test halting code with assert False * code reformat * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_generation_utils.py * fixing based on comments on PR * took out the testing code that should but work fails without the beam search moditification ; style changes * fixing comments issues * docstrings for ConstraintListState * typo in PhrsalConstraint docstring * docstrings improvements * finished adding what is sort of an opinionated implementation of disjunctive generation, but it revealed errors in inner beam search logic during testing. * fixed bug found in constrained beam search that used beam_idx that were not global across all the batches * disjunctive constraint working 100% correctly * passing all tests * Accidentally included mlruns * Update src/transformers/generation_beam_constraints.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/generation_beam_constraints.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * complete overhaul of type complexities and other nits * strict type checks in generate() * fixing second round of feedback by narsil * fixed failing generation test because of type check overhaul * generation test fail fix * fixing test fails Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
115
tests/generation/test_generation_beam_constraints.py
Normal file
115
tests/generation/test_generation_beam_constraints.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a clone of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation_beam_constraints import DisjunctiveConstraint
|
||||
|
||||
|
||||
@require_torch
|
||||
class ConstraintTest(unittest.TestCase):
|
||||
def test_input_types(self):
|
||||
# For consistency across different places the DisjunctiveConstraint is called,
|
||||
# dc.token_ids is a list of integers. It is also initialized only by integers.
|
||||
|
||||
cset = [[1, 2, 4], [1, 2, 3, 4]]
|
||||
dc = DisjunctiveConstraint(cset)
|
||||
self.assertTrue(isinstance(dc.token_ids, list))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
DisjunctiveConstraint(torch.LongTensor([[1, 2, 4], [1, 2, 3]]))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
DisjunctiveConstraint([torch.LongTensor([1, 2, 4]), torch.LongTensor([1, 2, 3, 4, 5])])
|
||||
|
||||
def test_check_illegal_input(self):
|
||||
# We can't have constraints that are complete subsets of another. This leads to a preverse
|
||||
# interpretation of "constraint fulfillment": does generating [1,2,3] fulfill the constraint?
|
||||
# It would mean that it generated [1,2] which fulfills it, but it's in the middle of potentially
|
||||
# fulfilling [1,2,3,4]. If we believe that [1,2,3] does fulfill the constraint, then the algorithm
|
||||
# will necessarily never reach [1,2,3,4], giving users a false sense of control (better to just not allow it).
|
||||
cset = [[1, 2], [1, 2, 3, 4]]
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
DisjunctiveConstraint(cset) # fails here
|
||||
|
||||
def test_example_progression(self):
|
||||
cset = [[1, 2, 3], [1, 2, 4]]
|
||||
|
||||
dc = DisjunctiveConstraint(cset)
|
||||
|
||||
stepped, completed, reset = dc.update(1)
|
||||
desired = stepped is True and completed is False and reset is False
|
||||
self.assertTrue(desired)
|
||||
self.assertTrue(not dc.completed)
|
||||
self.assertTrue(dc.current_seq == [1])
|
||||
|
||||
stepped, completed, reset = dc.update(2)
|
||||
desired = stepped is True and completed is False and reset is False
|
||||
self.assertTrue(desired)
|
||||
self.assertTrue(not dc.completed)
|
||||
self.assertTrue(dc.current_seq == [1, 2])
|
||||
|
||||
stepped, completed, reset = dc.update(3)
|
||||
desired = stepped is True and completed is True and reset is False
|
||||
self.assertTrue(desired)
|
||||
self.assertTrue(dc.completed) # Completed!
|
||||
self.assertTrue(dc.current_seq == [1, 2, 3])
|
||||
|
||||
def test_example_progression_unequal_three_mid_and_reset(self):
|
||||
cset = [[1, 2, 3], [1, 2, 4, 5], [1, 2, 5]]
|
||||
|
||||
dc = DisjunctiveConstraint(cset)
|
||||
|
||||
stepped, completed, reset = dc.update(1)
|
||||
self.assertTrue(not dc.completed)
|
||||
self.assertTrue(dc.current_seq == [1])
|
||||
|
||||
stepped, completed, reset = dc.update(2)
|
||||
self.assertTrue(not dc.completed)
|
||||
self.assertTrue(dc.current_seq == [1, 2])
|
||||
|
||||
stepped, completed, reset = dc.update(4)
|
||||
self.assertTrue(not dc.completed)
|
||||
self.assertTrue(dc.current_seq == [1, 2, 4])
|
||||
|
||||
stepped, completed, reset = dc.update(5)
|
||||
self.assertTrue(dc.completed) # Completed!
|
||||
self.assertTrue(dc.current_seq == [1, 2, 4, 5])
|
||||
|
||||
dc.reset()
|
||||
|
||||
stepped, completed, reset = dc.update(1)
|
||||
self.assertTrue(not dc.completed)
|
||||
self.assertTrue(dc.remaining() == 3)
|
||||
self.assertTrue(dc.current_seq == [1])
|
||||
|
||||
stepped, completed, reset = dc.update(2)
|
||||
self.assertTrue(not dc.completed)
|
||||
self.assertTrue(dc.remaining() == 2)
|
||||
self.assertTrue(dc.current_seq == [1, 2])
|
||||
|
||||
stepped, completed, reset = dc.update(5)
|
||||
self.assertTrue(dc.completed) # Completed!
|
||||
self.assertTrue(dc.remaining() == 0)
|
||||
self.assertTrue(dc.current_seq == [1, 2, 5])
|
||||
@@ -25,7 +25,7 @@ from ..test_modeling_common import floats_tensor, ids_tensor
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation_beam_constraints import PhrasalConstraint
|
||||
from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
|
||||
|
||||
@@ -260,10 +260,10 @@ class ConstrainedBeamSearchTester:
|
||||
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
||||
|
||||
if constraints is None:
|
||||
force_tokens = torch.randint(10, 50, (1, 2)).type(torch.LongTensor)[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
force_tokens = torch.randint(10, 50, (1, 2))[0].tolist()
|
||||
disjunctive_tokens = torch.randint(10, 50, (2, 2)).tolist()
|
||||
|
||||
constraints = [PhrasalConstraint(force_tokens), DisjunctiveConstraint(disjunctive_tokens)]
|
||||
self.constraints = constraints
|
||||
# cannot be randomely generated
|
||||
self.eos_token_id = vocab_size + 1
|
||||
@@ -331,7 +331,13 @@ class ConstrainedBeamSearchTester:
|
||||
):
|
||||
# check too many eos tokens
|
||||
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
|
||||
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
|
||||
stacked_token_ids = []
|
||||
for constraint in self.constraints:
|
||||
token_ids = constraint.token_ids
|
||||
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids
|
||||
stacked_token_ids = stacked_token_ids + token_ids
|
||||
|
||||
fulfilling_sequence = torch.LongTensor(stacked_token_ids)
|
||||
fulfill_len = fulfilling_sequence.size(0)
|
||||
input_ids[:, :fulfill_len] = fulfilling_sequence
|
||||
|
||||
@@ -398,7 +404,14 @@ class ConstrainedBeamSearchTester:
|
||||
max_length = self.sequence_length + 1
|
||||
|
||||
# for testing finalize, we do want to have fulfilled constraints
|
||||
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
|
||||
stacked_token_ids = []
|
||||
for constraint in self.constraints:
|
||||
token_ids = constraint.token_ids
|
||||
token_ids = token_ids[0] if isinstance(token_ids[0], list) else token_ids
|
||||
stacked_token_ids = stacked_token_ids + token_ids
|
||||
|
||||
fulfilling_sequence = torch.LongTensor(stacked_token_ids)
|
||||
|
||||
fulfill_len = fulfilling_sequence.size(0)
|
||||
input_ids[:, :fulfill_len] = fulfilling_sequence
|
||||
|
||||
@@ -451,9 +464,17 @@ class ConstrainedBeamSearchTester:
|
||||
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
|
||||
|
||||
# test that the constraint is indeed fulfilled
|
||||
for output in sequences:
|
||||
for constraint in constraints:
|
||||
forced_token_ids = constraint.token_ids
|
||||
for (output, constraint) in [(s, c) for s in sequences for c in constraints]:
|
||||
forced_token_ids = constraint.token_ids
|
||||
if isinstance(forced_token_ids[0], list):
|
||||
# disjunctive case
|
||||
flag = False
|
||||
for token_ids in forced_token_ids:
|
||||
if self._check_sequence_inside_sequence(output, token_ids):
|
||||
flag = True
|
||||
break
|
||||
self.parent.assertEqual(flag, True)
|
||||
else:
|
||||
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)
|
||||
|
||||
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
|
||||
@@ -479,18 +500,23 @@ class ConstrainedBeamSearchTester:
|
||||
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
|
||||
|
||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
|
||||
# set to same device. we don't care what device.
|
||||
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
|
||||
|
||||
in_order = tensor_1.size(0) <= tensor_2.size(0)
|
||||
if not isinstance(tensor_1, list):
|
||||
tensor_1 = tensor_1.cpu().tolist()
|
||||
if not isinstance(tensor_2, list):
|
||||
tensor_2 = tensor_2.cpu().tolist()
|
||||
|
||||
in_order = len(tensor_1) <= len(tensor_2)
|
||||
longer = tensor_2 if in_order else tensor_1
|
||||
shorter = tensor_1 if in_order else tensor_2
|
||||
|
||||
flag = False
|
||||
chunk_size = shorter.size(0)
|
||||
for chunk_idx in range(longer.size(0) - chunk_size + 1):
|
||||
chunk_size = len(shorter)
|
||||
for chunk_idx in range(len(longer) - chunk_size + 1):
|
||||
subseq = longer[chunk_idx : chunk_idx + chunk_size]
|
||||
if torch.equal(subseq, shorter):
|
||||
if subseq == shorter:
|
||||
flag = True
|
||||
break
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ if is_torch_available():
|
||||
VisionEncoderDecoderModel,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
from transformers.generation_beam_constraints import PhrasalConstraint
|
||||
from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||
from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from transformers.generation_logits_process import (
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
@@ -1202,7 +1202,7 @@ class GenerationTesterMixin:
|
||||
min_id = 3
|
||||
max_id = 100
|
||||
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
@@ -1227,7 +1227,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# check `generate()` and `constrained_beam_search()` are equal for `num_return_sequences`
|
||||
# Sample constraints
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
@@ -1288,7 +1288,7 @@ class GenerationTesterMixin:
|
||||
# otherwise this throws an error for Speech2TextModel since its inputs are floating points
|
||||
min_id = 3
|
||||
max_id = 100
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).type(torch.LongTensor)[0]
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
@@ -1499,18 +1499,23 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
|
||||
# set to same device. we don't care what device.
|
||||
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
|
||||
|
||||
in_order = tensor_1.size(0) <= tensor_2.size(0)
|
||||
if not isinstance(tensor_1, list):
|
||||
tensor_1 = tensor_1.cpu().tolist()
|
||||
if not isinstance(tensor_2, list):
|
||||
tensor_2 = tensor_2.cpu().tolist()
|
||||
|
||||
in_order = len(tensor_1) <= len(tensor_2)
|
||||
longer = tensor_2 if in_order else tensor_1
|
||||
shorter = tensor_1 if in_order else tensor_2
|
||||
|
||||
flag = False
|
||||
chunk_size = shorter.size(0)
|
||||
for chunk_idx in range(longer.size(0) - chunk_size + 1):
|
||||
chunk_size = len(shorter)
|
||||
for chunk_idx in range(len(longer) - chunk_size + 1):
|
||||
subseq = longer[chunk_idx : chunk_idx + chunk_size]
|
||||
if torch.equal(subseq, shorter):
|
||||
if subseq == shorter:
|
||||
flag = True
|
||||
break
|
||||
|
||||
@@ -2315,8 +2320,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
|
||||
force_tokens = tokenizer.encode(" scared", return_tensors="pt").to(torch_device)[0]
|
||||
force_tokens_2 = tokenizer.encode(" big weapons", return_tensors="pt").to(torch_device)[0]
|
||||
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
@@ -2346,6 +2351,105 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
|
||||
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
flexible_phrases = tokenizer(
|
||||
["scream", "screams", "screaming", "screamed"], add_prefix_space=True, add_special_tokens=False
|
||||
).input_ids
|
||||
|
||||
constraints = [
|
||||
PhrasalConstraint(force_phrase),
|
||||
DisjunctiveConstraint(flexible_phrases),
|
||||
]
|
||||
|
||||
starting_text = ["The soldiers", "The child"]
|
||||
|
||||
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
constraints=constraints,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
# max_length=20,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"The soldiers, who were all scared and screaming at each other as they tried to get out of the",
|
||||
"The child was taken to a local hospital where she screamed and scared for her life, police said.",
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed_mixin(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
|
||||
force_word = "scared"
|
||||
force_flexible = ["scream", "screams", "screaming", "screamed"]
|
||||
|
||||
force_words_ids = [
|
||||
tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
|
||||
tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
|
||||
]
|
||||
|
||||
starting_text = ["The soldiers", "The child"]
|
||||
|
||||
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
force_words_ids=force_words_ids,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"The soldiers, who were all scared and screaming at each other as they tried to get out of the",
|
||||
"The child was taken to a local hospital where she screamed and scared for her life, police said.",
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_example_translation_mixin(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||
|
||||
encoder_input_str = "translate English to German: How old are you?"
|
||||
force_words = ["sind"]
|
||||
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
force_words_ids=force_words_ids,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(outputs, ["Wie alter sind Sie?"])
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_example_integration(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||
@@ -2389,3 +2493,43 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(outputs, ["Wie alter sind Sie?"])
|
||||
|
||||
def test_constrained_beam_search_mixin_type_checks(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
||||
|
||||
encoder_input_str = "translate English to German: How old are you?"
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
force_words = ["sind"]
|
||||
force_words_ids = tokenizer(force_words, return_tensors="pt").input_ids
|
||||
model.generate(
|
||||
input_ids,
|
||||
force_words_ids=force_words_ids,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
force_words = ["sind"]
|
||||
force_words_ids = [tokenizer(force_words, return_tensors="pt").input_ids]
|
||||
model.generate(
|
||||
input_ids,
|
||||
force_words_ids=force_words_ids,
|
||||
num_beams=10,
|
||||
num_return_sequences=1,
|
||||
no_repeat_ngram_size=1,
|
||||
remove_invalid_values=True,
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, force_words_ids=[])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, force_words_ids=[[-1]])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, force_words_ids=[[[-1]]])
|
||||
|
||||
Reference in New Issue
Block a user