Forbid PretrainedConfig from saving generate parameters; Update deprecations in generate-related code 🧹 (#32659)

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Joao Gante
2024-08-23 11:12:53 +01:00
committed by GitHub
parent 22e6f14525
commit 970a16ec7f
53 changed files with 195 additions and 670 deletions

View File

@@ -18,7 +18,7 @@ import shutil
import unittest
from unittest.mock import patch
from transformers.testing_utils import CaptureStd, is_pt_tf_cross_test, require_torch
from transformers.testing_utils import CaptureStd, require_torch
class CLITest(unittest.TestCase):
@@ -33,18 +33,6 @@ class CLITest(unittest.TestCase):
self.assertIn("Platform", cs.out)
self.assertIn("Using distributed or parallel set-up in script?", cs.out)
@is_pt_tf_cross_test
@patch(
"sys.argv", ["fakeprogrampath", "pt-to-tf", "--model-name", "hf-internal-testing/tiny-random-gptj", "--no-pr"]
)
def test_cli_pt_to_tf(self):
import transformers.commands.transformers_cli
shutil.rmtree("/tmp/hf-internal-testing/tiny-random-gptj", ignore_errors=True) # cleans potential past runs
transformers.commands.transformers_cli.main()
self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))
@require_torch
@patch("sys.argv", ["fakeprogrampath", "download", "hf-internal-testing/tiny-random-gptj", "--cache-dir", "/tmp"])
def test_cli_download(self):