Fix torch.fx issue related to the new loss_kwargs keyword argument (#34380)

* Fix FX

* Unskip tests
This commit is contained in:
Michael Benayoun
2024-10-24 18:34:28 +02:00
committed by GitHub
parent d9989e0b9a
commit 1c5918d910
6 changed files with 1 additions and 6 deletions

View File

@@ -304,7 +304,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()