* create test for #38916 (custom generate from local dir with imports)
This commit is contained in:
committed by
GitHub
parent
25c44d4b68
commit
3abeaba7e5
@@ -22,6 +22,7 @@ import random
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@@ -4995,6 +4996,27 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
model.generate(**model_inputs, custom_generate="transformers-community/custom_generate_example")
|
model.generate(**model_inputs, custom_generate="transformers-community/custom_generate_example")
|
||||||
|
|
||||||
|
def test_custom_generate_local_directory(self):
|
||||||
|
"""Tests that custom_generate works with local directories containing importable relative modules"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
custom_generate_dir = Path(tmp_dir) / "custom_generate"
|
||||||
|
custom_generate_dir.mkdir()
|
||||||
|
with open(custom_generate_dir / "generate.py", "w") as f:
|
||||||
|
f.write("from .helper import ret_success\ndef generate(*args, **kwargs):\n return ret_success()\n")
|
||||||
|
with open(custom_generate_dir / "helper.py", "w") as f:
|
||||||
|
f.write('def ret_success():\n return "success"\n')
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||||
|
model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
|
||||||
|
value = model.generate(
|
||||||
|
**model_inputs,
|
||||||
|
custom_generate=str(tmp_dir),
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
assert value == "success"
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TokenHealingTestCase(unittest.TestCase):
|
class TokenHealingTestCase(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user