Conversation pipeline fixes (#26795)
* Adjust length limits and allow naked conversation list inputs * Adjust length limits and allow naked conversation list inputs * Maybe use a slightly more reasonable limit than 1024 * Skip tests for old models that never supported this anyway * Cleanup input docstrings * More docstring cleanup + skip failing TF test * Make fixup
This commit is contained in:
@@ -507,6 +507,10 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
model.generate(input_ids, attention_mask=attention_mask)
|
||||
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
||||
|
||||
@@ -343,6 +343,10 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
||||
# check that the output for the restored model is the same
|
||||
self.assert_outputs_same(restored_model_outputs, outputs)
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
def _long_tensor(tok_lst):
|
||||
return tf.constant(tok_lst, dtype=tf.int32)
|
||||
|
||||
@@ -891,6 +891,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
def test_disk_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
class T5EncoderOnlyModelTester:
|
||||
def __init__(
|
||||
|
||||
@@ -314,6 +314,10 @@ class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_keras_save_load(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
class TFT5EncoderOnlyModelTester:
|
||||
def __init__(
|
||||
@@ -607,6 +611,10 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
expected_output_string = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"]
|
||||
self.assertListEqual(expected_output_string, output_strings)
|
||||
|
||||
@unittest.skip("Does not support conversations.")
|
||||
def test_pipeline_conversational(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_sentencepiece
|
||||
|
||||
Reference in New Issue
Block a user